You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
89 lines
2.8 KiB
89 lines
2.8 KiB
import psycopg2
|
|
from psycopg2 import OperationalError
|
|
from psycopg2.extras import RealDictCursor
|
|
import json
|
|
from datetime import date, datetime
|
|
|
|
|
|
class PostgreSQLUtil:
    """Thin convenience wrapper around a single psycopg2 connection.

    The instance stores the connection parameters, opens the connection on
    ``connect()`` (or on entering a ``with`` block) and closes it on
    ``close()`` (or on leaving the ``with`` block). Queries are run through
    ``execute_query`` / ``query_to_json``.
    """

    # NOTE(review): the default host and credentials are hard-coded in the
    # signature. They are kept only so existing callers that rely on the
    # defaults keep working; prefer passing them in from configuration or
    # the environment and rotating the password that is committed here.
    def __init__(self, host="10.10.14.71", port=5432,
                 dbname="szjz_db", user="postgres", password="DsideaL147258369"):
        """Store the connection parameters; no connection is opened yet."""
        self.conn_params = {
            "host": host,
            "port": port,
            "dbname": dbname,
            "user": user,
            "password": password,
        }
        # Populated by connect(); stays None until then.
        self.connection = None

    def __enter__(self):
        """Open the connection and return this helper for use in ``with``."""
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection; exceptions from the ``with`` body propagate."""
        self.close()

    def connect(self):
        """Open the database connection using the stored parameters.

        Raises:
            OperationalError: re-raised after being reported on stdout when
                psycopg2 cannot establish the connection.
        """
        try:
            self.connection = psycopg2.connect(**self.conn_params)
        except OperationalError as e:
            print(f"连接错误: {e}")
            raise
        print("成功连接到PostgreSQL数据库")

    def close(self):
        """Close the connection if one has been opened."""
        if self.connection:
            self.connection.close()
            print("数据库连接已关闭")

    def execute_query(self, sql, params=None, return_dict=True):
        """Execute one SQL statement and return its result.

        The statement is committed when it succeeds and rolled back when it
        fails; a failed statement is never followed by a commit.

        Args:
            sql: SQL text passed to ``cursor.execute``.
            params: optional parameters passed to ``cursor.execute``.
            return_dict: when true the cursor is created with
                ``RealDictCursor`` and the fetched rows are returned as-is;
                when false a default cursor is used and each row is zipped
                with the column names into a plain ``dict``.

        Returns:
            A list with one mapping per row when the statement produced a
            result set, otherwise ``{"rowcount": <cursor.rowcount>}``.

        Raises:
            Exception: any error raised while executing or fetching is
                reported on stdout, the transaction is rolled back, and the
                error is re-raised unchanged.
        """
        try:
            with self.connection.cursor(
                cursor_factory=RealDictCursor if return_dict else None
            ) as cursor:
                cursor.execute(sql, params)

                if cursor.description:
                    rows = cursor.fetchall()
                    if return_dict:
                        result = rows
                    else:
                        # Default cursors yield sequences; pair each value
                        # with its column name to build dict rows.
                        columns = [desc[0] for desc in cursor.description]
                        result = [dict(zip(columns, row)) for row in rows]
                else:
                    # No result set: report the affected-row count.
                    result = {"rowcount": cursor.rowcount}
        except Exception as e:
            print(f"执行SQL出错: {e}")
            self.connection.rollback()
            raise

        # Commit only on the success path. Committing from a ``finally``
        # clause would also run after the rollback above and could replace
        # the original error with one raised by the commit itself.
        self.connection.commit()
        return result

    def query_to_json(self, sql, params=None):
        """Run ``execute_query`` and return its result as a JSON string.

        Values that ``json`` cannot encode natively are passed through
        ``json_serializer``.
        """
        data = self.execute_query(sql, params)
        return json.dumps(data, default=self.json_serializer)

    @staticmethod
    def json_serializer(obj):
        """``json.dumps`` fallback for values it cannot encode natively.

        ``date`` and ``datetime`` values become ISO 8601 strings; any other
        type raises ``TypeError``.
        """
        if isinstance(obj, (date, datetime)):
            return obj.isoformat()
        raise TypeError(f"Type {type(obj)} not serializable")
|
|
|
|
|
|
# Usage example: runs only when this file is executed as a script.
if __name__ == "__main__":
    with PostgreSQLUtil() as db:
        # Sample query: ask the server for its version.
        version_rows = db.execute_query("SELECT version()")
        print("数据库版本:", version_rows)

        # Same helper, but with the rows serialized to a JSON string.
        sample_sql = "SELECT * FROM t_base_class LIMIT 2"
        sample_json = db.query_to_json(sample_sql)
        print("JSON结果:", sample_json)