From 1c00eae62e27dc3f644dbc2f9c555234d74c939f Mon Sep 17 00:00:00 2001 From: htylight Date: Sun, 1 Oct 2023 18:43:25 +0800 Subject: [PATCH] implement slice transmitting of chat image --- ...c889c1715f_create_unrecieved_msg_table.py} | 16 +-- src/crud/group_chat_crud.py | 17 +-- src/crud/unreceived_msg_crud.py | 6 +- src/database/json_typeddict.py | 4 +- src/database/models.py | 9 +- src/response_models/message_response.py | 21 ++++ src/routers/message.py | 114 +++++++++++++++--- src/utils/static_file.py | 87 +++++++++++-- src/utils/web_socket.py | 43 +++++-- 9 files changed, 258 insertions(+), 59 deletions(-) rename migrations/versions/{ef4cbdcc711b_create_unreceived_msg_table.py => 5ac889c1715f_create_unrecieved_msg_table.py} (74%) create mode 100644 src/response_models/message_response.py diff --git a/migrations/versions/ef4cbdcc711b_create_unreceived_msg_table.py b/migrations/versions/5ac889c1715f_create_unrecieved_msg_table.py similarity index 74% rename from migrations/versions/ef4cbdcc711b_create_unreceived_msg_table.py rename to migrations/versions/5ac889c1715f_create_unrecieved_msg_table.py index 1b0737e..162da6d 100644 --- a/migrations/versions/ef4cbdcc711b_create_unreceived_msg_table.py +++ b/migrations/versions/5ac889c1715f_create_unrecieved_msg_table.py @@ -1,8 +1,8 @@ -"""create unreceived_msg table +"""create unrecieved_msg table -Revision ID: ef4cbdcc711b +Revision ID: 5ac889c1715f Revises: 4947792c7572 -Create Date: 2023-09-16 10:47:50.809077 +Create Date: 2023-09-23 19:48:51.530634 """ from alembic import op @@ -10,7 +10,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision = 'ef4cbdcc711b' +revision = '5ac889c1715f' down_revision = '4947792c7572' branch_labels = None depends_on = None @@ -19,19 +19,21 @@ depends_on = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.create_table('unreceived_msg', - sa.Column('msg_id', sa.String(length=16), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('msg_id', sa.String(length=18), nullable=False), sa.Column('event', sa.String(), nullable=False), sa.Column('type', sa.String(), nullable=False), sa.Column('receiver_id', sa.String(length=26), nullable=False), sa.Column('sender_id', sa.String(length=26), nullable=False), sa.Column('group_chat_id', sa.String(length=11), nullable=True), - sa.Column('name', sa.String(), nullable=True), + sa.Column('remark_in_group_chat', sa.String(), nullable=True), + sa.Column('nickname', sa.String(), nullable=True), sa.Column('avatar', sa.String(), nullable=True), sa.Column('text', sa.String(), nullable=False), sa.Column('attachments', sa.ARRAY(sa.String()), nullable=False), sa.Column('date_time', sa.String(), nullable=False), sa.Column('is_show_time', sa.Boolean(), nullable=False), - sa.PrimaryKeyConstraint('msg_id') + sa.PrimaryKeyConstraint('id') ) # ### end Alembic commands ### diff --git a/src/crud/group_chat_crud.py b/src/crud/group_chat_crud.py index d3676e5..1779bf5 100644 --- a/src/crud/group_chat_crud.py +++ b/src/crud/group_chat_crud.py @@ -29,7 +29,7 @@ async def insert_group_chat(supervisor: str, members: list[str]) -> GroupChat: select(Contact).where(Contact.user_id.in_(members)) ) for contact in contact_res.all(): - contact.group_chats[id] = {"nameRemark": "", "myRemark": ""} + contact.group_chats[id] = {"groupChatRemark": "", "remarkInGroupChat": ""} flag_modified(contact, "group_chats") session.add_all(contact_res.all()) @@ -57,7 +57,10 @@ async def insert_group_chat_members(group_chat_id: str, members: list[str]): contact: Contact = ( await session.scalars(select(Contact).where(Contact.user_id == member)) ).one() - contact.group_chats[group_chat_id] = {"myRemark": "", "nameRemark": ""} + contact.group_chats[group_chat_id] = { + "remarkInGroupChat": "", + "groupChatRemark": "", + } flag_modified(contact, "group_chats") session.add(contact) @@ -75,7 +78,7 @@ async def select_member_name_avatar(member_id: str, is_friend: bool) -> dict | T select(Contact.group_chats).where(Contact.user_id == member_id) ) await session.close() - # {'81906574618': {'myRemark': '', 'nameRemark': ''}} + # {'81906574618': {'remarkInGroupChat': '', 'groupChatRemark': ''}} return res.one() else: res = await session.execute( @@ -84,7 +87,7 @@ async def select_member_name_avatar(member_id: str, is_friend: bool) -> dict | T .where(UserProfile.user_id == member_id) ) await session.close() - # ('htylight', 'cznowoyn1692502503.png', {'81906574618': {'myRemark': '', 'nameRemark': ''},}) + # ('htylight', 'cznowoyn1692502503.png', {'81906574618': {'remarkInGroupChat': '', 'groupChatRemark': ''},}) return res.one() @@ -99,7 +102,7 @@ async def select_full_profile(group_chat_id: str) -> Tuple[GroupChat, list[Tuple UserProfile.user_id, UserProfile.nickname, UserProfile.avatar, - Contact.group_chats[group_chat_id]["myRemark"], + Contact.group_chats[group_chat_id]["remarkInGroupChat"], ) .join(Contact, UserProfile.user_id == Contact.user_id) .where(UserProfile.user_id.in_(members)) @@ -147,7 +150,7 @@ async def update_group_remark(user_id: str, group_chat_id: str, new_remark: str) contact: Contact = ( await session.scalars(select(Contact).where(Contact.user_id == user_id)) ).one() - contact.group_chats[group_chat_id]["nameRemark"] = new_remark + contact.group_chats[group_chat_id]["groupChatRemark"] = new_remark flag_modified(contact, "group_chats") session.add(contact) await session.commit() @@ -163,7 +166,7 @@ async def update_my_remark(user_id: str, group_chat_id: str, new_my_remark: str) contact: Contact = ( await session.scalars(select(Contact).where(Contact.user_id == user_id)) ).one() - contact.group_chats[group_chat_id]["myRemark"] = new_my_remark + contact.group_chats[group_chat_id]["remarkInGroupChat"] = new_my_remark flag_modified(contact, "group_chats") session.add(contact) await session.commit() diff --git a/src/crud/unreceived_msg_crud.py b/src/crud/unreceived_msg_crud.py index 9ddbef6..b5a2f05 100644 --- a/src/crud/unreceived_msg_crud.py +++ b/src/crud/unreceived_msg_crud.py @@ -17,7 +17,8 @@ async def insert_unreceived_msg( date_time: str, is_show_time, group_chat_id: str = None, - name: str = None, + nickname: str = None, + remark_in_group_chat: str = None, avatar: str = None, ): session = async_session() @@ -30,7 +31,8 @@ async def insert_unreceived_msg( sender_id=sender_id, receiver_id=receiver_id, group_chat_id=group_chat_id, - name=name, + nickname=nickname, + remark_in_group_chat=remark_in_group_chat, avatar=avatar, text=text, attachments=attachments, diff --git a/src/database/json_typeddict.py b/src/database/json_typeddict.py index ce9af44..36373be 100644 --- a/src/database/json_typeddict.py +++ b/src/database/json_typeddict.py @@ -7,5 +7,5 @@ class FriendSetting(TypedDict): class GroupChatSetting(TypedDict): - nameRemark: str | None - myRemark: str | None + groupChatRemark: str | None + remarkInGroupChat: str | None diff --git a/src/database/models.py b/src/database/models.py index 6a57468..20db2c2 100755 --- a/src/database/models.py +++ b/src/database/models.py @@ -168,13 +168,15 @@ class GroupChat(Base): class UnreceivedMsg(Base): __tablename__ = "unreceived_msg" - msg_id: Mapped[str] = mapped_column(String(16), primary_key=True) + id: Mapped[int] = mapped_column(Integer, primary_key=True) + msg_id: Mapped[str] = mapped_column(String(18)) event: Mapped[str] = mapped_column(String) type: Mapped[str] = mapped_column(String) receiver_id: Mapped[str] = mapped_column(String(26)) sender_id: Mapped[str] = mapped_column(String(26)) group_chat_id: Mapped[str] = mapped_column(String(11), nullable=True) - name: Mapped[str] = mapped_column(String, nullable=True) + remark_in_group_chat: Mapped[str] = mapped_column(String, nullable=True) + nickname: Mapped[str] = mapped_column(String, nullable=True) avatar: Mapped[str] = mapped_column(String, nullable=True) text: Mapped[str] = mapped_column(String) attachments: Mapped[list[str]] = mapped_column(ARRAY(String)) @@ -188,6 +190,9 @@ class UnreceivedMsg(Base): "type": self.type, "senderId": self.sender_id, "groupChatId": self.group_chat_id, + "nickname": self.nickname, + "remarkInGroupChat": self.remark_in_group_chat, + "avatar": self.avatar, "text": self.text, "attachments": self.attachments, "dateTime": self.date_time, diff --git a/src/response_models/message_response.py b/src/response_models/message_response.py new file mode 100644 index 0000000..0fdf13a --- /dev/null +++ b/src/response_models/message_response.py @@ -0,0 +1,21 @@ +from .base import BaseModel, BaseResponseModel + + +class _UnreceivedMsg(BaseModel): + msgId: str + event: str + type: str + receiver_id: str + sender_id: str + group_chat_id: str + nickname: str | None + remarkInGroupChat: str | None + avatar: str | None + text: str + attachments: list[str] + date_time: str + is_show_time: bool + + +class UnreceivedMsgResponse(BaseResponseModel): + data: _UnreceivedMsg | None = None diff --git a/src/routers/message.py b/src/routers/message.py index 3195e4c..ce15bb9 100644 --- a/src/routers/message.py +++ b/src/routers/message.py @@ -1,12 +1,15 @@ import threading import asyncio -from fastapi import APIRouter, WebSocket, WebSocketDisconnect, BackgroundTasks +from fastapi import APIRouter, WebSocket, WebSocketDisconnect, BackgroundTasks, Depends +from pydantic import BaseModel +from ..dependencies import verify_token from ..utils.web_socket import WebSocketManager -from ..utils.static_file import read_chat_file +from ..utils.static_file import read_chat_file, async_read_chat_file from ..crud import unreceived_msg_crud from ..database.models import UnreceivedMsg +from ..response_models.message_response import UnreceivedMsgResponse router = APIRouter(tags=["message"]) @@ -26,14 +29,20 @@ async def push_unsent_messages(): msgs: list[UnreceivedMsg] = unreceived_msg_crud.select_msgs(user_id) for msg in msgs: await ws_manager.send_to_another(user_id, msg.to_dict()) - if msg.attachments: - for attachment in msg.attachments: - byte_array = read_chat_file(attachment) + for attachment in msg.attachments: + for ( + current_chunk_num, + total_chunk_num, + byte_array, + ) in read_chat_file(attachment): await ws_manager.send_to_another( user_id, { "event": "chat-image", "filename": attachment, + "tempFilename": f"temp/{attachment}-${total_chunk_num}-${current_chunk_num}", + "totalChunkNum": total_chunk_num, + "currentChunkNum": current_chunk_num, "bytes": byte_array, }, ) @@ -76,11 +85,22 @@ async def connect_websocket(websocket: WebSocket, user_id: str, is_reconnect: bo try: while True: data = await websocket.receive_json() + match data["event"]: case "ping": await ws_manager.active_socket[user_id].send_json({"type": "pong"}) case "friend-chat-msg": await ws_manager.send_to_another(data["receiverId"], data) + if len(data["attachments"]) > 0: + await ws_manager.send_to_another( + data["senderId"], + { + "event": "pull-chat-image", + "chatType": 0, + "attachments": data["attachments"], + "receiverId": data["receiverId"], + }, + ) case "apply-friend": await ws_manager.send_to_another(data["recipient"], data) case "friend-added" | "friend-deleted": @@ -91,18 +111,36 @@ async def connect_websocket(websocket: WebSocket, user_id: str, is_reconnect: bo else: receiver_ids = data["receiverIds"] del data["receiverIds"] - await ws_manager.broadcast(data, receiver_ids) - case "group-chat-creation" | "group-chat-msg": + await ws_manager.broadcast(receiver_ids, data) + case "group-chat-creation": receiver_ids = data["receiverIds"] del data["receiverIds"] - await ws_manager.broadcast(data, receiver_ids) + await ws_manager.broadcast(receiver_ids, data) + case "group-chat-msg": + receiver_ids = data["receiverIds"] + del data["receiverIds"] + await ws_manager.broadcast(receiver_ids, data) + if len(data["attachments"]) > 0: + await ws_manager.send_to_another( + data["senderId"], + { + "event": "pull-chat-image", + "chatType": 1, + "attachments": data["attachments"], + "receiverIds": receiver_ids, + }, + ) except WebSocketDisconnect: print(f"{user_id} disconnect") ws_manager.disconnect(user_id) -@router.get("/message/unreceived") +@router.get( + "/message/unreceived", + response_model=UnreceivedMsgResponse, + dependencies=[Depends(verify_token)], +) async def get_unreceived_msgs( receiver_id: str, background_tasks: BackgroundTasks, @@ -120,11 +158,12 @@ async def get_unreceived_msgs( all_msg_attachments.extend(msg.attachments) json_msgs.append(msg.to_dict()) - background_tasks.add_task( - send_image_by_websocket, - receiver_id, - all_msg_attachments, - ) + if all_msg_attachments: + background_tasks.add_task( + send_image_by_websocket, + receiver_id, + all_msg_attachments, + ) return { "code": 10900, @@ -132,13 +171,48 @@ async def get_unreceived_msgs( "data": json_msgs, } except Exception as e: - return {} + print(e) + return {"code": 9999, "msg": "Server Error"} + + +class UploadAttachment(BaseModel): + event: str + senderId: str + receiverId: str | None = None + receiverIds: list[str] | None = None + filename: str + tempFilename: str + totalChunkNum: int + currentChunkNum: int + bytes: list[int] + + +@router.post("/message/attachment") +async def upload_attachment(data: UploadAttachment): + data = data.model_dump() + if data.get("receiverId"): + await ws_manager.send_to_another(data["receiverId"], data) + else: + await ws_manager.broadcast(data["receiverIds"], data) + + return {"code": 10900, "msg": "Ok"} async def send_image_by_websocket(receiver_id: str, attachments: list[str]): for attachment in attachments: - byte_array = await read_chat_file() - await ws_manager.send_to_another( - receiver_id, - {"event": "chat-image", "filename": attachment, "bytes": byte_array}, - ) + for ( + current_chunk_num, + total_chunk_num, + byte_array, + ) in await async_read_chat_file(attachment): + await ws_manager.send_to_another( + receiver_id, + { + "event": "chat-image", + "filename": attachment, + "tempFilename": f"temp/{attachment}-${total_chunk_num}-${current_chunk_num}", + "totalChunkNum": total_chunk_num, + "currentChunkNum": current_chunk_num, + "bytes": byte_array, + }, + ) diff --git a/src/utils/static_file.py b/src/utils/static_file.py index 6b5f245..44aed22 100755 --- a/src/utils/static_file.py +++ b/src/utils/static_file.py @@ -1,7 +1,8 @@ import os import random import array -from typing import Literal +import math +from typing import Literal, Tuple from pathlib import Path from datetime import datetime from zipfile import ZipFile @@ -37,6 +38,8 @@ alphabet = [ "z", ] +CHUNK_SIZE = 1024 * 1024 * 1 + def create_avatar_dir(type: Literal["user", "group_chat"], dir_name: str) -> Path: if type == "user": @@ -72,25 +75,82 @@ def create_zip_file(filenames: list[str], file_type: Literal["avatars"]) -> Path async def write_chat_file( - file_path: str, - file_data: list[int], + msg: dict, file_type: Literal["image", "video"], ): - if (Path(os.getcwd()) / "static" / "chat" / file_type / file_path).exists(): + filename = msg["filename"] + total_chunk_num = msg["totalChunkNum"] + temp_filename = msg["tempFilename"] + + if (Path(os.getcwd()) / "static" / "chat" / file_type / filename).exists(): return match file_type: case "image": - sub_dir = file_path.split("/")[0] + sub_dir = filename.split("/")[0] chat_image_dir = Path(os.getcwd()) / "static" / "chat" / "images" if not (chat_image_dir / sub_dir).exists(): os.makedirs(chat_image_dir / sub_dir) - async with aiofiles.open(chat_image_dir / file_path, "wb") as f: - await f.write(bytearray(file_data)) + temp_image_dir = Path(os.getcwd()) / "static" / "temp" + if not (temp_image_dir / sub_dir).exists(): + os.makedirs(temp_image_dir / sub_dir) + + if total_chunk_num == 1: + async with aiofiles.open(chat_image_dir / filename, "wb") as file: + await file.write(bytearray(msg["bytes"])) + return + + if not (temp_image_dir.parent / temp_filename).exists(): + try: + async with aiofiles.open( + temp_image_dir.parent / temp_filename, "wb" + ) as f: + await f.write(bytearray(msg["bytes"])) + except Exception as e: + print(f"write temp file fail with error: {e}") + + temp_file_path_exist: list[bool] = [ + (temp_image_dir / f"{filename}-{total_chunk_num}-{i}").exists() + for i in range(total_chunk_num) + ] + + # assemble the file when all chunk is arrived + if all(temp_file_path_exist): + async with aiofiles.open(chat_image_dir / filename, "ab+") as file: + for i in range(total_chunk_num): + async with aiofiles.open( + temp_image_dir / f"{filename}-{total_chunk_num}-{i}", "rb" + ) as chunk_file: + bytes_content = await chunk_file.read() + await file.write(bytes_content) + os.remove(temp_image_dir / f"{filename}-{total_chunk_num}-{i}") case "video": pass -def read_chat_file(file_path: str) -> list[int]: +def read_chat_file(filename: str) -> Tuple[int, int, list[int]]: + file_suffix: str = filename.split(".")[1] + + if file_suffix == "png": + file_type = "images" + else: + file_type = "videos" + + p = Path(os.getcwd()) / "static" / "chat" / file_type / filename + file_size = p.stat().st_size + total_chunk_num = math.ceil(file_size / CHUNK_SIZE) + + if not p.exists(): + yield -1, 0, [] + else: + file = open(p, "rb") + for i in range(total_chunk_num): + byte_data = file.read(CHUNK_SIZE) + yield i, total_chunk_num, array.array("B", byte_data).tolist() + + file.close() + + +async def async_read_chat_file(file_path: str) -> Tuple[int, int, list[int]]: file_suffix: str = file_path.split(".")[1] if file_suffix == "png": @@ -99,10 +159,13 @@ def read_chat_file(file_path: str) -> list[int]: file_type = "videos" p = Path(os.getcwd()) / "static" / "chat" / file_type / file_path + file_size = p.stat().st_size + total_chunk_num = math.ceil(file_size / CHUNK_SIZE) if not p.exists(): - return [] + yield -1, 0, [] else: - with open(p, "rb") as f: - byte_data = f.read() - return array.array("B", byte_data).tolist() + async with aiofiles.open(p, "rb") as f: + for i in range(total_chunk_num): + byte_data = await f.read(CHUNK_SIZE) + yield i, total_chunk_num, array.array("B", byte_data).tolist() diff --git a/src/utils/web_socket.py b/src/utils/web_socket.py index 11a4185..3de897f 100644 --- a/src/utils/web_socket.py +++ b/src/utils/web_socket.py @@ -30,6 +30,7 @@ class WebSocketManager: if self.active_socket.get(another): socket = self.active_socket.get(another) await socket.send_json(msg) + else: if msg["event"] == "friend-chat-msg": await insert_unreceived_msg( @@ -43,10 +44,26 @@ class WebSocketManager: msg["dateTime"], msg["isShowTime"], ) - elif msg["event"] == "chat-image": - await write_chat_file(msg["filename"], msg["bytes"], "image") + if msg["event"] == "chat-image": + await write_chat_file(msg, "image") + if self.active_socket.get(msg["senderId"]): + ws = self.active_socket.get(msg["senderId"]) + await ws.send_json( + { + "event": "chat-image-send-ok", + "chatType": 0, + "receiverId": another, + "currentChunkNum": msg["currentChunkNum"], + "totalChunkNum": msg["totalChunkNum"], + "filename": msg["filename"], + } + ) - async def broadcast(self, msg: dict, receiver_ids: list[str]): + async def broadcast( + self, + receiver_ids: list[str], + msg: dict, + ): for receiver_id in receiver_ids: socket = self.active_socket.get(receiver_id) if socket: @@ -62,10 +79,22 @@ class WebSocketManager: msg["text"], msg["attachments"], msg["dateTime"], - msg["is_show_time"], + msg["isShowTime"], group_chat_id=msg["groupChatId"], - name=msg["name"], + nickname=msg["nickname"], + remark_in_group_chat=msg["remarkInGroupChat"], avatar=msg["avatar"], ) - elif msg["event"] == "chat-image": - await write_chat_file(msg["filename"], msg["bytes"], "image") + if msg["event"] == "chat-image": + await write_chat_file(msg, "image") + if self.active_socket.get(msg["senderId"]): + await self.active_socket["senderId"].send_json( + { + "event": "chat-image-send-ok", + "chatType": 1, + "receiverIds": receiver_ids, + "currentChunkNum": msg["currentChunkNum"], + "totalChunkNum": msg["totalChunkNum"], + "filename": msg["filename"], + } + )