"""
Implement datastores backed by SQL databases (currently SQLite)
"""
import sqlite3
from werkzeug.exceptions import NotFound, BadRequest
from .base import DataStore
from rest_api_framework.models import PkField
class SQLiteDataStore(DataStore):
    """
    Define a sqlite datastore for your ressource. You have to give
    __init__ a data parameter containing the information needed to
    connect to the database and to the table.

    example:

    .. code-block:: python

        data = {"table": "tweets",
                "name": "test.db"}
        model = ApiModel
        datastore = SQLiteDataStore(data, **options)

    SQLiteDataStore implements a naive wrapper converting Field types
    into database types:

    * int will be saved in the database as INTEGER
    * float will be saved in the database as REAL
    * strings will be saved in the database as TEXT
    * a PkField will be saved as PRIMARY KEY (AUTOINCREMENT when the
      primary key is an integer)

    As soon as the datastore is instantiated, the database is created
    if it does not exist, and the table is created too.

    .. note::

       It is not possible to use a :memory: database: the same
       connection object is reused for every operation.
    """

    # Map Field.base_type values to SQLite column types.
    wrapper = {"integer": "integer",
               "float": "real",
               "string": "text"
               }

    def __init__(self, ressource_config, model, **options):
        """
        Open (creating it if needed) the database file and make sure
        the table described by ``model`` exists.

        :param ressource_config: dict with the keys ``name`` (database
            file path) and ``table`` (table name).
        :param model: the model describing the fields of the table.
        """
        self.db = ressource_config["name"]
        # check_same_thread=False: the single connection is shared by
        # every request thread.
        self.conn = sqlite3.connect(ressource_config["name"],
                                    check_same_thread=False)
        cursor = self.conn.cursor()
        table = ressource_config["table"]
        super(SQLiteDataStore, self).__init__(
            {"conn": self.conn, "table": table},
            model,
            **options)
        self.create_database(cursor, table)
        self.conn.commit()
        # Cached list of fields used by paginate() to zip row tuples
        # into dicts; get_list() overwrites it for partial responses.
        self.fields = self.model.get_fields()

    def create_database(self, cursor, table):
        """
        Issue a ``CREATE TABLE IF NOT EXISTS`` statement built from the
        model fields, including primary key, NOT NULL and FOREIGN KEY
        clauses.
        """
        statement = []
        for field in self.model.get_fields():
            query = "{0} {1}".format(field.name,
                                     self.wrapper[field.base_type])
            if isinstance(field, PkField):
                query += " primary key"
                if field.base_type == "integer":
                    query += " autoincrement"
            # BUG FIX: NOT NULL used to be appended after the column
            # definition had already been collected, so required
            # fields were silently created without the constraint.
            if "required" in field.options\
                    and field.options['required'] is True:
                query += " NOT NULL"
            statement.append(query)
        fields = ", ".join(statement)
        for field in self.model.get_fields():
            if "foreign" in field.options:
                fields += ",FOREIGN KEY ({0}) REFERENCES {1}({2})".format(
                    field.name, field.options["foreign"]["table"],
                    field.options["foreign"]["column"]
                )
        sql = 'create table if not exists {0} ({1})'.format(table, fields)
        cursor.execute(sql)

    def get_connector(self):
        """
        return a sqlite3 connection to communicate with the table
        defined in self.db. Foreign-key enforcement is switched on
        before the connection is handed out.
        """
        self.conn.execute('pragma foreign_keys=on')
        return self.conn

    def filter(self, **kwargs):
        """
        Change kwargs["query"] with "WHERE X=Y statements". The
        filtering will be done with the actual evaluation of the query
        in :meth:`~.SQLiteDataStore.paginate`; the sql can then be lazy
        """
        kwargs['query'] += ' FROM {0}'
        return kwargs

    def count(self, **data):
        """
        Return the number of rows matching the filters in ``data``
        (all rows when no filter applies).
        """
        cdt = self.build_conditions(data)
        if len(cdt) == 0:
            query = "SELECT COUNT (*) FROM {0}".format(
                self.ressource_config['table'])
        else:
            query = "SELECT COUNT (*) FROM {0} WHERE {1}".format(
                self.ressource_config['table'],
                " AND ".join(cdt)
            )
        cursor = self.get_connector().cursor()
        cursor.execute(query)
        return cursor.fetchone()[0]

    def build_conditions(self, data):
        """
        Turn the filters of ``data`` into a list of ``field='value'``
        SQL fragments. Single quotes in values are doubled so a quote
        in a filter value can neither break the statement nor inject
        SQL.
        """
        return [
            "{0}='{1}'".format(key, str(value).replace("'", "''"))
            for condition in self.get_conditions(data)
            for key, value in condition.items()]

    def get_conditions(self, data):
        """
        Remove from ``data`` (in place) every key that is neither a
        model field nor one of the reserved keys ``query``/``fields``,
        then return the remaining filters as a list of single-entry
        dicts.
        """
        unknown = [elem for elem in data
                   if elem not in ['query', 'fields']
                   and elem not in self.model.get_fields_name()]
        for elem in unknown:
            data.pop(elem)
        return [
            {k: v} for k, v in data.items() if k not in ["query", "fields"]
        ]

    def paginate(self, data, **kwargs):
        """
        paginate the result of filter using ids limits. Obviously, to
        work properly, you have to set the start to the last ids you
        received from the last call on this method. The max number of
        rows this method can give back depends on the paginate_by
        option.
        """
        pk_name = self.model.pk_field.name
        where_query = self.build_conditions(data)
        args = []
        limit = kwargs.pop("end", None)
        if kwargs.get("start", None):
            # BUG FIX: the primary-key column used to be hard-coded
            # as "id"; use the model's actual pk field instead.
            where_query.append(" {0} >=?".format(pk_name))
            args.append(kwargs.pop('start'))
        if len(where_query) > 0:
            data["query"] += " WHERE "
            data["query"] += " AND ".join(where_query)
        cursor = self.get_connector().cursor()
        # a hook for ordering
        data["query"] += " ORDER BY {0} ASC".format(pk_name)
        if limit:
            data["query"] += " LIMIT {0}".format(limit)
        cursor.execute(data["query"].format(self.ressource_config['table']),
                       tuple(args)
                       )
        return [dict(zip(self.fields, row)) for row in cursor.fetchall()]

    def get_fields(self, **fields):
        """
        Return the list of field names to load, honouring partial
        responses when the ``partial`` option is set. Raise BadRequest
        if an unknown field is requested; the primary key is always
        included.
        """
        if self.partial:
            fields, kwargs = self.partial.get_partials(**fields)
            if not fields:
                fields = self.model.get_fields_name()
            for field in fields:
                if field not in self.model.get_fields_name():
                    raise BadRequest()
            if self.model.pk_field.name not in fields:
                fields.append(self.model.pk_field.name)
        else:
            fields = self.model.get_fields_name()
        return fields

    def get_list(self, **kwargs):
        """
        return all the objects, paginated if needed, filtered if
        filters have been set.
        """
        self.fields = self.get_fields(**kwargs)
        fields = ", ".join(self.fields)
        kwargs["query"] = 'SELECT {0}'.format(fields)
        start = kwargs.pop("offset", None)
        end = kwargs.pop("count", None)
        data = self.filter(**kwargs)
        return self.paginate(data, start=start, end=end)

    def get(self, identifier):
        """
        Return a single row or raise NotFound
        """
        fields = self.model.get_fields_name()
        query = "select {0} from {1} where {2}=?".format(
            ",".join(fields),
            self.ressource_config["table"],
            self.model.pk_field.name)
        cursor = self.get_connector().cursor()
        cursor.execute(query, (identifier,))
        obj = cursor.fetchone()
        if obj is None:
            raise NotFound
        return dict(zip(fields, obj))

    def create(self, data):
        """
        Validate the data with :meth:`.base.DataStore.validate`
        And, if data is valid, create the row in database and return
        its primary key.
        """
        self.validate(data)
        fields = []
        values = []
        for k, v in data.items():
            if k in self.model.get_fields_name():
                fields.append(str(k))
                # Bind the value as-is and let sqlite3 apply the
                # column type affinity.
                values.append(v)
        cursor = self.conn.cursor()
        # BUG FIX: the column list used to be rendered with
        # str(tuple(fields)), which single-quoted the names and left a
        # trailing comma for a single column — a SQL syntax error.
        query = "insert into {0} ({1}) values ({2})".format(
            self.ressource_config["table"],
            ", ".join(fields),
            ", ".join("?" for _ in fields)
        )
        cursor.execute(query, tuple(values))
        self.conn.commit()
        return cursor.lastrowid

    def update(self, obj, data):
        """
        Retrieve the object to be updated
        (:meth:`~.SQLiteDataStore.get` will raise a NotFound error if
        the row does not exist)
        Validate the fields to be updated and return the updated row
        """
        pk_name = self.model.pk_field.name
        self.get(obj[pk_name])
        self.validate_fields(data)
        fields = []
        values = []
        for k, v in data.items():
            if k in self.model.get_fields_name():
                fields.append(k)
                values.append(v)
        conn = self.conn
        cursor = conn.cursor()
        # Parameterized: values used to be string-formatted into the
        # statement, breaking on quotes and open to SQL injection.
        query = "update {0} set {1} WHERE {2}=?".format(
            self.ressource_config["table"],
            ", ".join("{0}=?".format(f) for f in fields),
            pk_name
        )
        cursor.execute(query, tuple(values) + (obj[pk_name],))
        conn.commit()
        return self.get(obj[pk_name])

    def delete(self, identifier):
        """
        Retrieve the object to be deleted
        (:meth:`~.SQLiteDataStore.get` will raise a NotFound error if
        the row does not exist)
        Return None on success, raise a 400 error if a foreign key
        constraint prevents the delete.
        """
        self.get(identifier)
        conn = self.conn
        cursor = conn.cursor()
        # Parameterized: the identifier used to be formatted straight
        # into the statement.
        query = "delete from {0} where {1}=?".format(
            self.ressource_config["table"],
            self.model.pk_field.name)
        try:
            cursor.execute(query, (identifier,))
        except sqlite3.IntegrityError as e:
            message = ""
            if "foreign" in str(e):
                message = """another ressource depends on this
                object. Could not delete before all ressources
                depending on it are also deleted"""
            raise BadRequest(message)
        conn.commit()

    def __del__(self):
        # Close the shared connection; guard against a partially
        # initialized instance (__init__ failed before self.conn was
        # set) or an already closed connection.
        try:
            self.conn.close()
        except (AttributeError, sqlite3.ProgrammingError):
            pass