From ea038cab9ff360191bc0a41efa6708d4a2a57092 Mon Sep 17 00:00:00 2001 From: htylight Date: Thu, 5 Oct 2023 12:16:20 +0800 Subject: [PATCH] fig bugs that can't push unreceived msgs when client pull unreceived msgs --- src/crud/contact_crud.py | 23 ++++--- src/crud/group_chat_crud.py | 2 +- src/crud/unreceived_msg_crud.py | 14 ++-- src/database/db.py | 8 ++- src/database/models.py | 2 +- src/response_models/group_chat_response.py | 2 +- src/response_models/message_response.py | 14 ++-- src/routers/contact.py | 7 +- src/routers/group_chat.py | 6 +- src/routers/message.py | 76 ++++++++++++---------- src/utils/static_file.py | 11 ++-- src/utils/web_socket.py | 29 ++++----- 12 files changed, 107 insertions(+), 87 deletions(-) diff --git a/src/crud/contact_crud.py b/src/crud/contact_crud.py index 1bdf7d0..cf945fd 100644 --- a/src/crud/contact_crud.py +++ b/src/crud/contact_crud.py @@ -60,20 +60,25 @@ async def select_friends_group_chats( group_chat_ids: list[str], ) -> Tuple[list[Tuple[UserAccount, UserProfile]], list[GroupChat]]: session = async_session() + friends_res: Result[list[Tuple[UserAccount, UserProfile]]] = await session.execute( select(UserAccount, UserProfile) .join(UserAccount.profile) .where(UserAccount.id.in_(friend_ids)) ) - if group_chat_ids: - group_chats_res = await session.scalars( - select(GroupChat).where(GroupChat.id.in_(group_chat_ids)) - ) - await session.close() - return friends_res.all(), group_chats_res.all() - else: - await session.close() - return friends_res.all(), [] + + group_chats_res = await session.scalars( + select(GroupChat).where(GroupChat.id.in_(group_chat_ids)) + ) + + try: + group_chats = group_chats_res.all() + except Exception as e: + print(f"You now have not been in any group chat: {e}") + group_chats = [] + + await session.close() + return friends_res.all(), group_chats async def update_friend_setting( diff --git a/src/crud/group_chat_crud.py b/src/crud/group_chat_crud.py index 1779bf5..dd99b63 100644 --- a/src/crud/group_chat_crud.py +++ b/src/crud/group_chat_crud.py @@ -12,7 +12,7 @@ pool = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] def _create_random_id() -> str: random_chars = random.choices(pool, k=11) - return "".join(random_chars) + return "To" + "".join(random_chars) async def insert_group_chat(supervisor: str, members: list[str]) -> GroupChat: diff --git a/src/crud/unreceived_msg_crud.py b/src/crud/unreceived_msg_crud.py index b5a2f05..aa51a38 100644 --- a/src/crud/unreceived_msg_crud.py +++ b/src/crud/unreceived_msg_crud.py @@ -1,4 +1,4 @@ -from sqlalchemy import select, insert, delete, ScalarResult +from sqlalchemy import select, insert, delete, ScalarResult, CursorResult from sqlalchemy.orm import Session from sqlalchemy.exc import ResourceClosedError @@ -69,18 +69,20 @@ def select_msgs(receiver_id: str) -> list[UnreceivedMsg]: session.close() -async def delete_and_return_msgs(receiver_id: str) -> list[UnreceivedMsg]: +async def delete_and_return_msgs(receiver_id: str) -> list[()]: session = async_session() try: - res: ScalarResult[UnreceivedMsg] = await session.execute( + res: CursorResult[UnreceivedMsg] = await session.execute( delete(UnreceivedMsg) .where(UnreceivedMsg.receiver_id == receiver_id) - .returning() + .returning('*') ) + msgs = res.all() + print(msgs) await session.commit() await session.close() - return list(msgs) + return msgs except ResourceClosedError as e1: print(e1) raise e1 @@ -99,6 +101,6 @@ def delete_msgs(receiver_id: str): ) session.commit() except Exception as e: - print(e) + print(f"Deleting Unreceived msgs fail with error: {e}") finally: session.close() diff --git a/src/database/db.py b/src/database/db.py index 044ea27..cae9fa9 100755 --- a/src/database/db.py +++ b/src/database/db.py @@ -4,6 +4,10 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine db_url = "postgresql+psycopg2://together:togetherno.1@localhost/together" async_db_url = "postgresql+asyncpg://together:togetherno.1@localhost/together" -engine = create_engine(db_url, echo=True) -async_engine = create_async_engine(async_db_url, echo=True) +engine = create_engine( + db_url, +) +async_engine = create_async_engine( + async_db_url, +) async_session = async_sessionmaker(async_engine, expire_on_commit=False) diff --git a/src/database/models.py b/src/database/models.py index 20db2c2..452ab5d 100755 --- a/src/database/models.py +++ b/src/database/models.py @@ -138,7 +138,7 @@ class Apply(Base): class GroupChat(Base): __tablename__ = "group_chat" - id: Mapped[str] = mapped_column(String(11), primary_key=True) + id: Mapped[str] = mapped_column(String(13), primary_key=True) name: Mapped[str] = mapped_column(String) supervisor: Mapped[str] = mapped_column(String(26)) administrators: Mapped[list[str]] = mapped_column(ARRAY(String(26)), default=[]) diff --git a/src/response_models/group_chat_response.py b/src/response_models/group_chat_response.py index c858b69..bc01cd9 100644 --- a/src/response_models/group_chat_response.py +++ b/src/response_models/group_chat_response.py @@ -16,7 +16,7 @@ class _GroupChatProfile(BaseModel): class _MemberNameAvatar(BaseModel): - remark: str + remarkInGroupChat: str nickname: str avatar: str diff --git a/src/response_models/message_response.py b/src/response_models/message_response.py index 0fdf13a..ce3d571 100644 --- a/src/response_models/message_response.py +++ b/src/response_models/message_response.py @@ -5,17 +5,17 @@ class _UnreceivedMsg(BaseModel): msgId: str event: str type: str - receiver_id: str - sender_id: str - group_chat_id: str + receiverId: str + senderId: str + groupChatId: str | None nickname: str | None remarkInGroupChat: str | None avatar: str | None text: str - attachments: list[str] - date_time: str - is_show_time: bool + attachments: list[str | None] + dateTime: str + isShowTime: bool class UnreceivedMsgResponse(BaseResponseModel): - data: _UnreceivedMsg | None = None + data: list[_UnreceivedMsg] | None = None diff --git a/src/routers/contact.py b/src/routers/contact.py index f47379d..b6f5fca 100755 --- a/src/routers/contact.py +++ b/src/routers/contact.py @@ -14,7 +14,7 @@ router = APIRouter(prefix="/contact", tags=["contact"]) class ContactIds(BaseModel): friend_ids: list[str] - group_chat_ids: list[str] | None = None + group_chat_ids: list[str] class ChangeFriendSetting(BaseModel): @@ -55,9 +55,8 @@ async def get_contact_account_profiles(contact_ids: ContactIds): friends_account_profiles[account.id] = account.to_dict() friends_account_profiles[account.id].update(profile.to_dict()) - if group_chats_res: - for group_chat in group_chats_res: - group_chats_profiles[group_chat.id] = group_chat.to_dict() + for group_chat in group_chats_res: + group_chats_profiles[group_chat.id] = group_chat.to_dict() return { "code": 10700, diff --git a/src/routers/group_chat.py b/src/routers/group_chat.py index 5b6eb3a..ed8f193 100644 --- a/src/routers/group_chat.py +++ b/src/routers/group_chat.py @@ -59,13 +59,13 @@ async def get_member_name_avatar(group_chat_id: str, member_id: str, is_friend: if is_friend: if res.get(group_chat_id): # make sure my friend is still in this group chat - data["remark"] = res[group_chat_id]["myRemark"] + data["remarkInGroupChat"] = res[group_chat_id]["remarkInGroupChat"] data["nickname"] = "" data["avatar"] = "" else: if len(res) == 3: # make sure this user is still in this group chat - data["remark"] = res[2][group_chat_id]["myRemark"] + data["remarkInGroupChat"] = res[2][group_chat_id]["remarkInGroupChat"] data["nickname"] = res[0] data["avatar"] = res[1] or "" @@ -88,7 +88,7 @@ async def get_full_profile(group_chat_id: str): for member_name_avatar in member_name_avatar_list: member_name_avatar_dict[member_name_avatar[0]] = { - "remark": member_name_avatar[3], + "remarkInGroupChat": member_name_avatar[3], "nickname": member_name_avatar[1], "avatar": member_name_avatar[2] or "", } diff --git a/src/routers/message.py b/src/routers/message.py index ce15bb9..337e392 100644 --- a/src/routers/message.py +++ b/src/routers/message.py @@ -15,6 +15,7 @@ router = APIRouter(tags=["message"]) ws_manager = WebSocketManager() + loop = asyncio.new_event_loop() @@ -31,21 +32,21 @@ async def push_unsent_messages(): await ws_manager.send_to_another(user_id, msg.to_dict()) for attachment in msg.attachments: for ( - current_chunk_num, - total_chunk_num, - byte_array, + current_chunk_num, + total_chunk_num, + byte_array, ) in read_chat_file(attachment): - await ws_manager.send_to_another( - user_id, + await ws_manager.active_socket[user_id].send_json( { "event": "chat-image", "filename": attachment, - "tempFilename": f"temp/{attachment}-${total_chunk_num}-${current_chunk_num}", + "tempFilename": f"temp/{attachment}-{total_chunk_num}-{current_chunk_num}", "totalChunkNum": total_chunk_num, "currentChunkNum": current_chunk_num, "bytes": byte_array, - }, + } ) + if msgs: unreceived_msg_crud.delete_msgs(user_id) except Exception as e: @@ -72,10 +73,14 @@ def message_startup_event(): @router.on_event("shutdown") async def message_shutdown_event(): print("关闭所有连接............") - loop.stop() + # loop.stop() ws_manager.disconnect_all() +async def send_unreceived_attachments(receiver_id: str, msg: dict): + await ws_manager.active_socket[receiver_id].send_json(msg) + + @router.websocket("/ws/{user_id}") async def connect_websocket(websocket: WebSocket, user_id: str, is_reconnect: bool): if is_reconnect: @@ -142,11 +147,11 @@ async def connect_websocket(websocket: WebSocket, user_id: str, is_reconnect: bo dependencies=[Depends(verify_token)], ) async def get_unreceived_msgs( - receiver_id: str, - background_tasks: BackgroundTasks, + receiver_id: str, + background_tasks: BackgroundTasks, ): try: - msgs: list[UnreceivedMsg] = await unreceived_msg_crud.delete_and_return_msgs( + msgs: list[()] = await unreceived_msg_crud.delete_and_return_msgs( receiver_id ) msgs.sort(key=lambda msg: msg.msg_id) @@ -154,9 +159,12 @@ async def get_unreceived_msgs( all_msg_attachments = [] for msg in msgs: - if not msg.attachments: - all_msg_attachments.extend(msg.attachments) - json_msgs.append(msg.to_dict()) + if msg[11]: + all_msg_attachments.extend(msg[11]) + json_msgs.append( + {'msgId': msg[1], 'event': msg[2], 'type': msg[3], 'receiverId': msg[4], 'senderId': msg[5], + 'groupChatId': msg[6], 'nickname': msg[7], 'remarkInGroupChat': msg[8], 'avatar': msg[9], + 'text': msg[10], 'attachments': msg[11], 'dateTime': msg[12], 'isShowTime': msg[13]}) if all_msg_attachments: background_tasks.add_task( @@ -175,6 +183,26 @@ async def get_unreceived_msgs( return {"code": 9999, "msg": "Server Error"} +async def send_image_by_websocket(receiver_id: str, attachments: list[str]): + print("send_image_by_websocket") + for attachment in attachments: + async for ( + current_chunk_num, + total_chunk_num, + byte_array, + ) in async_read_chat_file(attachment): + await ws_manager.active_socket[receiver_id].send_json( + { + "event": "chat-image", + "filename": attachment, + "tempFilename": f"temp/{attachment}-{total_chunk_num}-{current_chunk_num}", + "totalChunkNum": total_chunk_num, + "currentChunkNum": current_chunk_num, + "bytes": byte_array, + } + ) + + class UploadAttachment(BaseModel): event: str senderId: str @@ -196,23 +224,3 @@ async def upload_attachment(data: UploadAttachment): await ws_manager.broadcast(data["receiverIds"], data) return {"code": 10900, "msg": "Ok"} - - -async def send_image_by_websocket(receiver_id: str, attachments: list[str]): - for attachment in attachments: - for ( - current_chunk_num, - total_chunk_num, - byte_array, - ) in await async_read_chat_file(attachment): - await ws_manager.send_to_another( - receiver_id, - { - "event": "chat-image", - "filename": attachment, - "tempFilename": f"temp/{attachment}-${total_chunk_num}-${current_chunk_num}", - "totalChunkNum": total_chunk_num, - "currentChunkNum": current_chunk_num, - "bytes": byte_array, - }, - ) diff --git a/src/utils/static_file.py b/src/utils/static_file.py index 44aed22..17aa953 100755 --- a/src/utils/static_file.py +++ b/src/utils/static_file.py @@ -2,7 +2,7 @@ import os import random import array import math -from typing import Literal, Tuple +from typing import Literal, Tuple, AsyncIterable from pathlib import Path from datetime import datetime from zipfile import ZipFile @@ -38,7 +38,7 @@ alphabet = [ "z", ] -CHUNK_SIZE = 1024 * 1024 * 1 +CHUNK_SIZE = 1024 * 1024 * 2 def create_avatar_dir(type: Literal["user", "group_chat"], dir_name: str) -> Path: @@ -122,7 +122,10 @@ async def write_chat_file( ) as chunk_file: bytes_content = await chunk_file.read() await file.write(bytes_content) - os.remove(temp_image_dir / f"{filename}-{total_chunk_num}-{i}") + # os.remove( + # temp_image_dir / f"{filename}-{total_chunk_num}-{i}", + # dir_fd=None, + # ) case "video": pass @@ -150,7 +153,7 @@ def read_chat_file(filename: str) -> Tuple[int, int, list[int]]: file.close() -async def async_read_chat_file(file_path: str) -> Tuple[int, int, list[int]]: +async def async_read_chat_file(file_path: str) -> AsyncIterable[Tuple[int, int, list[int]]]: file_suffix: str = file_path.split(".")[1] if file_suffix == "png": diff --git a/src/utils/web_socket.py b/src/utils/web_socket.py index 3de897f..636e71d 100644 --- a/src/utils/web_socket.py +++ b/src/utils/web_socket.py @@ -47,8 +47,7 @@ class WebSocketManager: if msg["event"] == "chat-image": await write_chat_file(msg, "image") if self.active_socket.get(msg["senderId"]): - ws = self.active_socket.get(msg["senderId"]) - await ws.send_json( + await self.active_socket.get(msg["senderId"]).send_json( { "event": "chat-image-send-ok", "chatType": 0, @@ -85,16 +84,16 @@ class WebSocketManager: remark_in_group_chat=msg["remarkInGroupChat"], avatar=msg["avatar"], ) - if msg["event"] == "chat-image": - await write_chat_file(msg, "image") - if self.active_socket.get(msg["senderId"]): - await self.active_socket["senderId"].send_json( - { - "event": "chat-image-send-ok", - "chatType": 1, - "receiverIds": receiver_ids, - "currentChunkNum": msg["currentChunkNum"], - "totalChunkNum": msg["totalChunkNum"], - "filename": msg["filename"], - } - ) + if msg["event"] == "chat-image": + await write_chat_file(msg, "image") + if self.active_socket.get(msg["senderId"]): + await self.active_socket[msg["senderId"]].send_json( + { + "event": "chat-image-send-ok", + "chatType": 1, + "receiverIds": receiver_ids, + "currentChunkNum": msg["currentChunkNum"], + "totalChunkNum": msg["totalChunkNum"], + "filename": msg["filename"], + } + )