122 lines
2.6 KiB
Python
122 lines
2.6 KiB
Python
from pprint import pprint
|
|
import sqlite3
|
|
|
|
# Module-wide SQLite connection. 'babyorm.db' is a relative path, so it is
# resolved against the current working directory; SQLite creates the file
# on connect if it does not exist yet.
conn = sqlite3.connect(database='babyorm.db')

# Cursor on that shared connection, available to module-level code.
c = conn.cursor()
|
|
|
|
class Model(dict):
    """Tiny dict-backed ORM base class over the module-level ``conn``.

    Each subclass maps to the SQLite table whose name equals the subclass
    name (``class Users(Model)`` -> table ``Users``).  An instance is a dict
    of that table's column values.  Assigning a column as an attribute
    (``obj.email = ...``) also stores it in the dict, so ``save()`` persists
    it.  Keys that are not columns of the table are never stored in the dict.

    Values are always sent to SQLite as bound parameters.  Column names,
    which cannot be bound, are checked against the table schema before
    being placed in SQL text.  Caller-supplied data therefore cannot alter
    the statements.
    """

    def __init__(self, **kwargs):
        """Create an unsaved instance from ``column=value`` keywords.

        Keywords that are not columns of the table become plain instance
        attributes only and are not persisted.
        """
        # Bypass __setattr__: 'columns' is metadata, not a column value.
        super().__setattr__('columns', self.__columns())
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __setitem__(self, key, value):
        """Store ``self[key] = value`` only when *key* is a table column.

        Returns False, storing nothing, for any other key, so unknown keys
        are silently dropped.
        """
        if key in self.columns:
            return super().__setitem__(key, value)
        return False

    def __setattr__(self, key, value):
        """Set the attribute, and mirror it into the dict if it is a column."""
        self.__setitem__(key, value)
        return super().__setattr__(key, value)

    @staticmethod
    def __quote(identifier):
        """Return *identifier* as a double-quoted SQL identifier."""
        return '"{}"'.format(identifier.replace('"', '""'))

    @classmethod
    def __columns(cls, force_update=False):
        """Return this class's table column names, cached on the class.

        The cache is read from the class's own ``__dict__``, so a subclass
        never reuses its parent's columns.  An empty result means the table
        does not exist yet.  It is stored but not trusted, so a table created
        later is still picked up.  Pass ``force_update=True`` to re-read the
        schema, e.g. after ALTER TABLE.
        """
        cached = vars(cls).get('columns')
        if cached and not force_update:
            return cached
        cursor = conn.execute('pragma table_info("{}")'.format(cls.__name__))
        # Row index 1 of PRAGMA table_info is the column name.
        cls.columns = tuple(info[1] for info in cursor)
        return cls.columns

    @classmethod
    def __checked(cls, names):
        """Return *names* as a list after verifying each is a table column.

        Column names cannot be bound as SQL parameters.  This whitelist is
        what keeps caller-supplied keys out of the SQL text.

        Raises:
            sqlite3.OperationalError: if a name is not a column.  This is
                the same exception type SQLite raises for an unknown column.
        """
        names = list(names)
        columns = cls.__columns()
        if not set(names) <= set(columns):
            # The cached schema may be stale (e.g. after ALTER TABLE).
            columns = cls.__columns(force_update=True)
        for name in names:
            if name not in columns:
                raise sqlite3.OperationalError('no such column: {}'.format(name))
        return names

    @classmethod
    def __assignments(cls, obj, joiner):
        """Build a ``"col" = ?`` term for every item of the mapping *obj*.

        Returns a ``(sql, params)`` pair:
          - *sql* is the terms joined by *joiner*, e.g. " AND " for a WHERE
            clause or ", " for an UPDATE SET list;
          - *params* is the matching values, in the same order.
        """
        # dict.items: works for plain dicts and Model instances alike.
        items = list(dict.items(obj))
        names = cls.__checked(name for name, _ in items)
        sql = joiner.join('{} = ?'.format(cls.__quote(name)) for name in names)
        return sql, [value for _, value in items]

    @classmethod
    def __select(cls, criteria, limit=None):
        """Return instances for rows where every ``column == value`` holds.

        An empty *criteria* mapping matches every row.  *limit*, if given,
        caps the number of rows returned.
        """
        sql = 'SELECT * FROM {}'.format(cls.__quote(cls.__name__))
        params = []
        if criteria:
            where, params = cls.__assignments(criteria, ' AND ')
            sql += ' WHERE ' + where
        if limit is not None:
            sql += ' LIMIT ?'
            params.append(limit)
        cursor = conn.execute(sql, params)
        # Take column names from the result set itself so each value is
        # paired with the column it actually came from.
        names = [description[0] for description in cursor.description]
        return [cls(**dict(zip(names, row))) for row in cursor.fetchall()]

    @classmethod
    def all(cls):
        """Return every row of the table as instances."""
        return cls.__select({})

    @classmethod
    def get(cls, *args, **kwargs):
        """Return the first row matching the criteria, or None if none does.

        ``get(5)`` is shorthand for ``get(id=5)``.  Otherwise the keyword
        arguments are ``column=value`` equality tests combined with AND.
        """
        criteria = {'id': args[0]} if args else kwargs
        found = cls.__select(criteria, limit=1)
        return found[0] if found else None

    @classmethod
    def filter(cls, **kwargs):
        """Return all rows whose columns equal the given ``column=value`` pairs."""
        return cls.__select(kwargs)

    def save(self):
        """Insert this object if it has no id yet, otherwise update its row."""
        # dict.get explicitly: Model.get (the query method) shadows it.
        if dict.get(self, 'id') is not None:
            self.update()
        else:
            self.create()

    def create(self):
        """INSERT this object's column values as a new row, then commit.

        The new row's id (``lastrowid``) is assigned to ``self.id``.
        Returns self.
        """
        items = list(dict.items(self))
        table = self.__quote(type(self).__name__)
        if items:
            names = self.__checked(name for name, _ in items)
            sql = 'INSERT INTO {} ({}) VALUES ({})'.format(
                table,
                ', '.join(self.__quote(name) for name in names),
                ', '.join('?' * len(names)),
            )
        else:
            # "INSERT ... () VALUES ()" is invalid SQL; let every column
            # take its default value instead.
            sql = 'INSERT INTO {} DEFAULT VALUES'.format(table)
        cursor = conn.execute(sql, [value for _, value in items])
        setattr(self, 'id', cursor.lastrowid)
        conn.commit()
        return self

    def update(self):
        """Write every column value back to the row whose id is ``self['id']``.

        Commits afterwards.  ``self['id']`` must be set (KeyError otherwise).
        """
        assignments, params = self.__assignments(self, ', ')
        conn.execute(
            'UPDATE {} SET {} WHERE id = ?'.format(
                self.__quote(type(self).__name__), assignments),
            params + [self['id']],
        )
        conn.commit()
|
|
|
|
### Don't touch the code for these Model subclasses (Users, Stocks).
|
|
class Users(Model):
    # Maps to the SQLite table named "Users": Model derives the table name
    # from the subclass name, and reads the columns from that table's schema.
    pass
|
|
|
|
class Stocks(Model):
    # Maps to the SQLite table named "Stocks": Model derives the table name
    # from the subclass name, and reads the columns from that table's schema.
    pass
|
|
|
|
if __name__ == '__main__':
    # Script entry point. Running the module currently does nothing; the
    # usage walkthrough below is commented out.
    # When enabled, it does the following:
    #   - creates a Users row, then saves it again (the second save is an
    #     update, because create() stores the new id);
    #   - prints the results of all(), get() and filter().
    # NOTE(review): it presumably assumes babyorm.db already has a Users
    # table with id, name and email columns -- confirm before enabling.
    pass
    # dan = Users(name='dan')
    # dan.save()
    # dan.email = "dan@gmail.com"
    # dan.save()
    # pprint(dan)

    # print( 'all' )
    # pprint( Users.all() )
    # print( 'get' )
    # pprint( Users.get(name="Kenny") )
    # print( 'filter' )
    # pprint( Users.filter(id=dan['id']) )
|