summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--db.py28
2 files changed, 21 insertions, 8 deletions
diff --git a/.gitignore b/.gitignore
index 386c6dd..77b2ff9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@
*.pyc
/config
/cache
+/*.db
diff --git a/db.py b/db.py
index 00ce023..ebb5701 100644
--- a/db.py
+++ b/db.py
@@ -6,12 +6,27 @@ engine = create_engine(config.get('db_path'))
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
-from sqlalchemy import Column, Integer, String, ForeignKey
+from sqlalchemy import Column, Integer, Unicode, ForeignKey
from sqlalchemy.orm import relationship, backref
from sqlalchemy import and_, or_
from sqlalchemy.orm.exc import NoResultFound
+from sqlalchemy.types import TypeDecorator, Unicode
+
+class String(TypeDecorator):
+ impl = Unicode
+
+ def process_bind_param(self, value, dialect):
+ if isinstance(value, str):
+ value = value.decode('utf-8')
+ return value
+
+ def process_result_value(self, value, dialect):
+ if isinstance(value, unicode):
+ value = value.encode('utf-8')
+ return value
+
class Directory(Base):
__tablename__ = 'directories'
@@ -26,7 +41,7 @@ class Directory(Base):
self.parent_id = parent_id
def __repr__(self):
- return '<Directory("{0}")>'.format(self.path.encode('utf-8'))
+ return '<Directory("{0}")>'.format(self.path)
@staticmethod
def get(session, path, parent_id = None):
@@ -65,7 +80,7 @@ class Artist(Base):
self.name = name
def __repr__(self):
- return '<Artist("{0}")>'.format(self.name.encode('utf-8'))
+ return '<Artist("{0}")>'.format(self.name)
@staticmethod
def get(session, name):
@@ -129,7 +144,7 @@ class Track(Base):
self.album_id = album_id
def __repr__(self):
- return '<Track("{0}")>'.format(self.filename.encode('utf-8'))
+ return '<Track("{0}")>'.format(self.filename)
@staticmethod
def get(session, name, num, filename, file_index, directory_id, artist_id, album_id):
@@ -165,10 +180,7 @@ class Track(Base):
return r.all()
def get_path(self):
- s = os.path.join(self.directory.path, self.filename)
- if isinstance(s, unicode):
- s = s.encode('utf-8')
- return s
+ return os.path.join(self.directory.path, self.filename)
def get_relpath(self):
return os.path.relpath(self.get_path(), config.get('music_root'))