implement slice transmitting of chat image

main
htylight 2023-10-01 18:43:25 +08:00
parent d06c061e02
commit 1c00eae62e
9 changed files with 258 additions and 59 deletions

View File

@ -1,8 +1,8 @@
"""create unreceived_msg table """create unrecieved_msg table
Revision ID: ef4cbdcc711b Revision ID: 5ac889c1715f
Revises: 4947792c7572 Revises: 4947792c7572
Create Date: 2023-09-16 10:47:50.809077 Create Date: 2023-09-23 19:48:51.530634
""" """
from alembic import op from alembic import op
@ -10,7 +10,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'ef4cbdcc711b' revision = '5ac889c1715f'
down_revision = '4947792c7572' down_revision = '4947792c7572'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@ -19,19 +19,21 @@ depends_on = None
def upgrade() -> None: def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table('unreceived_msg', op.create_table('unreceived_msg',
sa.Column('msg_id', sa.String(length=16), nullable=False), sa.Column('id', sa.Integer(), nullable=False),
sa.Column('msg_id', sa.String(length=18), nullable=False),
sa.Column('event', sa.String(), nullable=False), sa.Column('event', sa.String(), nullable=False),
sa.Column('type', sa.String(), nullable=False), sa.Column('type', sa.String(), nullable=False),
sa.Column('receiver_id', sa.String(length=26), nullable=False), sa.Column('receiver_id', sa.String(length=26), nullable=False),
sa.Column('sender_id', sa.String(length=26), nullable=False), sa.Column('sender_id', sa.String(length=26), nullable=False),
sa.Column('group_chat_id', sa.String(length=11), nullable=True), sa.Column('group_chat_id', sa.String(length=11), nullable=True),
sa.Column('name', sa.String(), nullable=True), sa.Column('remark_in_group_chat', sa.String(), nullable=True),
sa.Column('nickname', sa.String(), nullable=True),
sa.Column('avatar', sa.String(), nullable=True), sa.Column('avatar', sa.String(), nullable=True),
sa.Column('text', sa.String(), nullable=False), sa.Column('text', sa.String(), nullable=False),
sa.Column('attachments', sa.ARRAY(sa.String()), nullable=False), sa.Column('attachments', sa.ARRAY(sa.String()), nullable=False),
sa.Column('date_time', sa.String(), nullable=False), sa.Column('date_time', sa.String(), nullable=False),
sa.Column('is_show_time', sa.Boolean(), nullable=False), sa.Column('is_show_time', sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint('msg_id') sa.PrimaryKeyConstraint('id')
) )
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@ -29,7 +29,7 @@ async def insert_group_chat(supervisor: str, members: list[str]) -> GroupChat:
select(Contact).where(Contact.user_id.in_(members)) select(Contact).where(Contact.user_id.in_(members))
) )
for contact in contact_res.all(): for contact in contact_res.all():
contact.group_chats[id] = {"nameRemark": "", "myRemark": ""} contact.group_chats[id] = {"groupChatRemark": "", "remarkInGroupChat": ""}
flag_modified(contact, "group_chats") flag_modified(contact, "group_chats")
session.add_all(contact_res.all()) session.add_all(contact_res.all())
@ -57,7 +57,10 @@ async def insert_group_chat_members(group_chat_id: str, members: list[str]):
contact: Contact = ( contact: Contact = (
await session.scalars(select(Contact).where(Contact.user_id == member)) await session.scalars(select(Contact).where(Contact.user_id == member))
).one() ).one()
contact.group_chats[group_chat_id] = {"myRemark": "", "nameRemark": ""} contact.group_chats[group_chat_id] = {
"remarkInGroupChat": "",
"groupChatRemark": "",
}
flag_modified(contact, "group_chats") flag_modified(contact, "group_chats")
session.add(contact) session.add(contact)
@ -75,7 +78,7 @@ async def select_member_name_avatar(member_id: str, is_friend: bool) -> dict | T
select(Contact.group_chats).where(Contact.user_id == member_id) select(Contact.group_chats).where(Contact.user_id == member_id)
) )
await session.close() await session.close()
# {'81906574618': {'myRemark': '', 'nameRemark': ''}} # {'81906574618': {'remarkInGroupChat': '', 'groupChatRemark': ''}}
return res.one() return res.one()
else: else:
res = await session.execute( res = await session.execute(
@ -84,7 +87,7 @@ async def select_member_name_avatar(member_id: str, is_friend: bool) -> dict | T
.where(UserProfile.user_id == member_id) .where(UserProfile.user_id == member_id)
) )
await session.close() await session.close()
# ('htylight', 'cznowoyn1692502503.png', {'81906574618': {'myRemark': '', 'nameRemark': ''},}) # ('htylight', 'cznowoyn1692502503.png', {'81906574618': {'remarkInGroupChat': '', 'groupChatRemark': ''},})
return res.one() return res.one()
@ -99,7 +102,7 @@ async def select_full_profile(group_chat_id: str) -> Tuple[GroupChat, list[Tuple
UserProfile.user_id, UserProfile.user_id,
UserProfile.nickname, UserProfile.nickname,
UserProfile.avatar, UserProfile.avatar,
Contact.group_chats[group_chat_id]["myRemark"], Contact.group_chats[group_chat_id]["remarkInGroupChat"],
) )
.join(Contact, UserProfile.user_id == Contact.user_id) .join(Contact, UserProfile.user_id == Contact.user_id)
.where(UserProfile.user_id.in_(members)) .where(UserProfile.user_id.in_(members))
@ -147,7 +150,7 @@ async def update_group_remark(user_id: str, group_chat_id: str, new_remark: str)
contact: Contact = ( contact: Contact = (
await session.scalars(select(Contact).where(Contact.user_id == user_id)) await session.scalars(select(Contact).where(Contact.user_id == user_id))
).one() ).one()
contact.group_chats[group_chat_id]["nameRemark"] = new_remark contact.group_chats[group_chat_id]["groupChatRemark"] = new_remark
flag_modified(contact, "group_chats") flag_modified(contact, "group_chats")
session.add(contact) session.add(contact)
await session.commit() await session.commit()
@ -163,7 +166,7 @@ async def update_my_remark(user_id: str, group_chat_id: str, new_my_remark: str)
contact: Contact = ( contact: Contact = (
await session.scalars(select(Contact).where(Contact.user_id == user_id)) await session.scalars(select(Contact).where(Contact.user_id == user_id))
).one() ).one()
contact.group_chats[group_chat_id]["myRemark"] = new_my_remark contact.group_chats[group_chat_id]["remarkInGroupChat"] = new_my_remark
flag_modified(contact, "group_chats") flag_modified(contact, "group_chats")
session.add(contact) session.add(contact)
await session.commit() await session.commit()

View File

@ -17,7 +17,8 @@ async def insert_unreceived_msg(
date_time: str, date_time: str,
is_show_time, is_show_time,
group_chat_id: str = None, group_chat_id: str = None,
name: str = None, nickname: str = None,
remark_in_group_chat: str = None,
avatar: str = None, avatar: str = None,
): ):
session = async_session() session = async_session()
@ -30,7 +31,8 @@ async def insert_unreceived_msg(
sender_id=sender_id, sender_id=sender_id,
receiver_id=receiver_id, receiver_id=receiver_id,
group_chat_id=group_chat_id, group_chat_id=group_chat_id,
name=name, nickname=nickname,
remark_in_group_chat=remark_in_group_chat,
avatar=avatar, avatar=avatar,
text=text, text=text,
attachments=attachments, attachments=attachments,

View File

@ -7,5 +7,5 @@ class FriendSetting(TypedDict):
class GroupChatSetting(TypedDict): class GroupChatSetting(TypedDict):
nameRemark: str | None groupChatRemark: str | None
myRemark: str | None remarkInGroupChat: str | None

View File

@ -168,13 +168,15 @@ class GroupChat(Base):
class UnreceivedMsg(Base): class UnreceivedMsg(Base):
__tablename__ = "unreceived_msg" __tablename__ = "unreceived_msg"
msg_id: Mapped[str] = mapped_column(String(16), primary_key=True) id: Mapped[int] = mapped_column(Integer, primary_key=True)
msg_id: Mapped[str] = mapped_column(String(18))
event: Mapped[str] = mapped_column(String) event: Mapped[str] = mapped_column(String)
type: Mapped[str] = mapped_column(String) type: Mapped[str] = mapped_column(String)
receiver_id: Mapped[str] = mapped_column(String(26)) receiver_id: Mapped[str] = mapped_column(String(26))
sender_id: Mapped[str] = mapped_column(String(26)) sender_id: Mapped[str] = mapped_column(String(26))
group_chat_id: Mapped[str] = mapped_column(String(11), nullable=True) group_chat_id: Mapped[str] = mapped_column(String(11), nullable=True)
name: Mapped[str] = mapped_column(String, nullable=True) remark_in_group_chat: Mapped[str] = mapped_column(String, nullable=True)
nickname: Mapped[str] = mapped_column(String, nullable=True)
avatar: Mapped[str] = mapped_column(String, nullable=True) avatar: Mapped[str] = mapped_column(String, nullable=True)
text: Mapped[str] = mapped_column(String) text: Mapped[str] = mapped_column(String)
attachments: Mapped[list[str]] = mapped_column(ARRAY(String)) attachments: Mapped[list[str]] = mapped_column(ARRAY(String))
@ -188,6 +190,9 @@ class UnreceivedMsg(Base):
"type": self.type, "type": self.type,
"senderId": self.sender_id, "senderId": self.sender_id,
"groupChatId": self.group_chat_id, "groupChatId": self.group_chat_id,
"nickname": self.nickname,
"remarkInGroupChat": self.remark_in_group_chat,
"avatar": self.avatar,
"text": self.text, "text": self.text,
"attachments": self.attachments, "attachments": self.attachments,
"dateTime": self.date_time, "dateTime": self.date_time,

View File

@ -0,0 +1,21 @@
from .base import BaseModel, BaseResponseModel
class _UnreceivedMsg(BaseModel):
msgId: str
event: str
type: str
receiver_id: str
sender_id: str
group_chat_id: str
nickname: str | None
remarkInGroupChat: str | None
avatar: str | None
text: str
attachments: list[str]
date_time: str
is_show_time: bool
class UnreceivedMsgResponse(BaseResponseModel):
data: _UnreceivedMsg | None = None

View File

@ -1,12 +1,15 @@
import threading import threading
import asyncio import asyncio
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, BackgroundTasks from fastapi import APIRouter, WebSocket, WebSocketDisconnect, BackgroundTasks, Depends
from pydantic import BaseModel
from ..dependencies import verify_token
from ..utils.web_socket import WebSocketManager from ..utils.web_socket import WebSocketManager
from ..utils.static_file import read_chat_file from ..utils.static_file import read_chat_file, async_read_chat_file
from ..crud import unreceived_msg_crud from ..crud import unreceived_msg_crud
from ..database.models import UnreceivedMsg from ..database.models import UnreceivedMsg
from ..response_models.message_response import UnreceivedMsgResponse
router = APIRouter(tags=["message"]) router = APIRouter(tags=["message"])
@ -26,14 +29,20 @@ async def push_unsent_messages():
msgs: list[UnreceivedMsg] = unreceived_msg_crud.select_msgs(user_id) msgs: list[UnreceivedMsg] = unreceived_msg_crud.select_msgs(user_id)
for msg in msgs: for msg in msgs:
await ws_manager.send_to_another(user_id, msg.to_dict()) await ws_manager.send_to_another(user_id, msg.to_dict())
if msg.attachments: for attachment in msg.attachments:
for attachment in msg.attachments: for (
byte_array = read_chat_file(attachment) current_chunk_num,
total_chunk_num,
byte_array,
) in read_chat_file(attachment):
await ws_manager.send_to_another( await ws_manager.send_to_another(
user_id, user_id,
{ {
"event": "chat-image", "event": "chat-image",
"filename": attachment, "filename": attachment,
"tempFilename": f"temp/{attachment}-${total_chunk_num}-${current_chunk_num}",
"totalChunkNum": total_chunk_num,
"currentChunkNum": current_chunk_num,
"bytes": byte_array, "bytes": byte_array,
}, },
) )
@ -76,11 +85,22 @@ async def connect_websocket(websocket: WebSocket, user_id: str, is_reconnect: bo
try: try:
while True: while True:
data = await websocket.receive_json() data = await websocket.receive_json()
match data["event"]: match data["event"]:
case "ping": case "ping":
await ws_manager.active_socket[user_id].send_json({"type": "pong"}) await ws_manager.active_socket[user_id].send_json({"type": "pong"})
case "friend-chat-msg": case "friend-chat-msg":
await ws_manager.send_to_another(data["receiverId"], data) await ws_manager.send_to_another(data["receiverId"], data)
if len(data["attachments"]) > 0:
await ws_manager.send_to_another(
data["senderId"],
{
"event": "pull-chat-image",
"chatType": 0,
"attachments": data["attachments"],
"receiverId": data["receiverId"],
},
)
case "apply-friend": case "apply-friend":
await ws_manager.send_to_another(data["recipient"], data) await ws_manager.send_to_another(data["recipient"], data)
case "friend-added" | "friend-deleted": case "friend-added" | "friend-deleted":
@ -91,18 +111,36 @@ async def connect_websocket(websocket: WebSocket, user_id: str, is_reconnect: bo
else: else:
receiver_ids = data["receiverIds"] receiver_ids = data["receiverIds"]
del data["receiverIds"] del data["receiverIds"]
await ws_manager.broadcast(data, receiver_ids) await ws_manager.broadcast(receiver_ids, data)
case "group-chat-creation" | "group-chat-msg": case "group-chat-creation":
receiver_ids = data["receiverIds"] receiver_ids = data["receiverIds"]
del data["receiverIds"] del data["receiverIds"]
await ws_manager.broadcast(data, receiver_ids) await ws_manager.broadcast(receiver_ids, data)
case "group-chat-msg":
receiver_ids = data["receiverIds"]
del data["receiverIds"]
await ws_manager.broadcast(receiver_ids, data)
if len(data["attachments"]) > 0:
await ws_manager.send_to_another(
data["senderId"],
{
"event": "pull-chat-image",
"chatType": 1,
"attachments": data["attachments"],
"receiverIds": receiver_ids,
},
)
except WebSocketDisconnect: except WebSocketDisconnect:
print(f"{user_id} disconnect") print(f"{user_id} disconnect")
ws_manager.disconnect(user_id) ws_manager.disconnect(user_id)
@router.get("/message/unreceived") @router.get(
"/message/unreceived",
response_model=UnreceivedMsgResponse,
dependencies=[Depends(verify_token)],
)
async def get_unreceived_msgs( async def get_unreceived_msgs(
receiver_id: str, receiver_id: str,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
@ -120,11 +158,12 @@ async def get_unreceived_msgs(
all_msg_attachments.extend(msg.attachments) all_msg_attachments.extend(msg.attachments)
json_msgs.append(msg.to_dict()) json_msgs.append(msg.to_dict())
background_tasks.add_task( if all_msg_attachments:
send_image_by_websocket, background_tasks.add_task(
receiver_id, send_image_by_websocket,
all_msg_attachments, receiver_id,
) all_msg_attachments,
)
return { return {
"code": 10900, "code": 10900,
@ -132,13 +171,48 @@ async def get_unreceived_msgs(
"data": json_msgs, "data": json_msgs,
} }
except Exception as e: except Exception as e:
return {} print(e)
return {"code": 9999, "msg": "Server Error"}
class UploadAttachment(BaseModel):
event: str
senderId: str
receiverId: str | None = None
receiverIds: list[str] | None = None
filename: str
tempFilename: str
totalChunkNum: int
currentChunkNum: int
bytes: list[int]
@router.post("/message/attachment")
async def upload_attachment(data: UploadAttachment):
data = data.model_dump()
if data.get("receiverId"):
await ws_manager.send_to_another(data["receiverId"], data)
else:
await ws_manager.broadcast(data["receiverIds"], data)
return {"code": 10900, "msg": "Ok"}
async def send_image_by_websocket(receiver_id: str, attachments: list[str]): async def send_image_by_websocket(receiver_id: str, attachments: list[str]):
for attachment in attachments: for attachment in attachments:
byte_array = await read_chat_file() for (
await ws_manager.send_to_another( current_chunk_num,
receiver_id, total_chunk_num,
{"event": "chat-image", "filename": attachment, "bytes": byte_array}, byte_array,
) ) in await async_read_chat_file(attachment):
await ws_manager.send_to_another(
receiver_id,
{
"event": "chat-image",
"filename": attachment,
"tempFilename": f"temp/{attachment}-${total_chunk_num}-${current_chunk_num}",
"totalChunkNum": total_chunk_num,
"currentChunkNum": current_chunk_num,
"bytes": byte_array,
},
)

View File

@ -1,7 +1,8 @@
import os import os
import random import random
import array import array
from typing import Literal import math
from typing import Literal, Tuple
from pathlib import Path from pathlib import Path
from datetime import datetime from datetime import datetime
from zipfile import ZipFile from zipfile import ZipFile
@ -37,6 +38,8 @@ alphabet = [
"z", "z",
] ]
CHUNK_SIZE = 1024 * 1024 * 1
def create_avatar_dir(type: Literal["user", "group_chat"], dir_name: str) -> Path: def create_avatar_dir(type: Literal["user", "group_chat"], dir_name: str) -> Path:
if type == "user": if type == "user":
@ -72,25 +75,82 @@ def create_zip_file(filenames: list[str], file_type: Literal["avatars"]) -> Path
async def write_chat_file( async def write_chat_file(
file_path: str, msg: dict,
file_data: list[int],
file_type: Literal["image", "video"], file_type: Literal["image", "video"],
): ):
if (Path(os.getcwd()) / "static" / "chat" / file_type / file_path).exists(): filename = msg["filename"]
total_chunk_num = msg["totalChunkNum"]
temp_filename = msg["tempFilename"]
if (Path(os.getcwd()) / "static" / "chat" / file_type / filename).exists():
return return
match file_type: match file_type:
case "image": case "image":
sub_dir = file_path.split("/")[0] sub_dir = filename.split("/")[0]
chat_image_dir = Path(os.getcwd()) / "static" / "chat" / "images" chat_image_dir = Path(os.getcwd()) / "static" / "chat" / "images"
if not (chat_image_dir / sub_dir).exists(): if not (chat_image_dir / sub_dir).exists():
os.makedirs(chat_image_dir / sub_dir) os.makedirs(chat_image_dir / sub_dir)
async with aiofiles.open(chat_image_dir / file_path, "wb") as f: temp_image_dir = Path(os.getcwd()) / "static" / "temp"
await f.write(bytearray(file_data)) if not (temp_image_dir / sub_dir).exists():
os.makedirs(temp_image_dir / sub_dir)
if total_chunk_num == 1:
async with aiofiles.open(chat_image_dir / filename, "wb") as file:
await file.write(bytearray(msg["bytes"]))
return
if not (temp_image_dir.parent / temp_filename).exists():
try:
async with aiofiles.open(
temp_image_dir.parent / temp_filename, "wb"
) as f:
await f.write(bytearray(msg["bytes"]))
except Exception as e:
print(f"write temp file fail with error: {e}")
temp_file_path_exist: list[bool] = [
(temp_image_dir / f"{filename}-{total_chunk_num}-{i}").exists()
for i in range(total_chunk_num)
]
# assemble the file when all chunk is arrived
if all(temp_file_path_exist):
async with aiofiles.open(chat_image_dir / filename, "ab+") as file:
for i in range(total_chunk_num):
async with aiofiles.open(
temp_image_dir / f"{filename}-{total_chunk_num}-{i}", "rb"
) as chunk_file:
bytes_content = await chunk_file.read()
await file.write(bytes_content)
os.remove(temp_image_dir / f"{filename}-{total_chunk_num}-{i}")
case "video": case "video":
pass pass
def read_chat_file(file_path: str) -> list[int]: def read_chat_file(filename: str) -> Tuple[int, int, list[int]]:
file_suffix: str = filename.split(".")[1]
if file_suffix == "png":
file_type = "images"
else:
file_type = "videos"
p = Path(os.getcwd()) / "static" / "chat" / file_type / filename
file_size = p.stat().st_size
total_chunk_num = math.ceil(file_size / CHUNK_SIZE)
if not p.exists():
yield -1, 0, []
else:
file = open(p, "rb")
for i in range(total_chunk_num):
byte_data = file.read(CHUNK_SIZE)
yield i, total_chunk_num, array.array("B", byte_data).tolist()
file.close()
async def async_read_chat_file(file_path: str) -> Tuple[int, int, list[int]]:
file_suffix: str = file_path.split(".")[1] file_suffix: str = file_path.split(".")[1]
if file_suffix == "png": if file_suffix == "png":
@ -99,10 +159,13 @@ def read_chat_file(file_path: str) -> list[int]:
file_type = "videos" file_type = "videos"
p = Path(os.getcwd()) / "static" / "chat" / file_type / file_path p = Path(os.getcwd()) / "static" / "chat" / file_type / file_path
file_size = p.stat().st_size
total_chunk_num = math.ceil(file_size / CHUNK_SIZE)
if not p.exists(): if not p.exists():
return [] yield -1, 0, []
else: else:
with open(p, "rb") as f: async with aiofiles.open(p, "rb") as f:
byte_data = f.read() for i in range(total_chunk_num):
return array.array("B", byte_data).tolist() byte_data = await f.read(CHUNK_SIZE)
yield i, total_chunk_num, array.array("B", byte_data).tolist()

View File

@ -30,6 +30,7 @@ class WebSocketManager:
if self.active_socket.get(another): if self.active_socket.get(another):
socket = self.active_socket.get(another) socket = self.active_socket.get(another)
await socket.send_json(msg) await socket.send_json(msg)
else: else:
if msg["event"] == "friend-chat-msg": if msg["event"] == "friend-chat-msg":
await insert_unreceived_msg( await insert_unreceived_msg(
@ -43,10 +44,26 @@ class WebSocketManager:
msg["dateTime"], msg["dateTime"],
msg["isShowTime"], msg["isShowTime"],
) )
elif msg["event"] == "chat-image": if msg["event"] == "chat-image":
await write_chat_file(msg["filename"], msg["bytes"], "image") await write_chat_file(msg, "image")
if self.active_socket.get(msg["senderId"]):
ws = self.active_socket.get(msg["senderId"])
await ws.send_json(
{
"event": "chat-image-send-ok",
"chatType": 0,
"receiverId": another,
"currentChunkNum": msg["currentChunkNum"],
"totalChunkNum": msg["totalChunkNum"],
"filename": msg["filename"],
}
)
async def broadcast(self, msg: dict, receiver_ids: list[str]): async def broadcast(
self,
receiver_ids: list[str],
msg: dict,
):
for receiver_id in receiver_ids: for receiver_id in receiver_ids:
socket = self.active_socket.get(receiver_id) socket = self.active_socket.get(receiver_id)
if socket: if socket:
@ -62,10 +79,22 @@ class WebSocketManager:
msg["text"], msg["text"],
msg["attachments"], msg["attachments"],
msg["dateTime"], msg["dateTime"],
msg["is_show_time"], msg["isShowTime"],
group_chat_id=msg["groupChatId"], group_chat_id=msg["groupChatId"],
name=msg["name"], nickname=msg["nickname"],
remark_in_group_chat=msg["remarkInGroupChat"],
avatar=msg["avatar"], avatar=msg["avatar"],
) )
elif msg["event"] == "chat-image": if msg["event"] == "chat-image":
await write_chat_file(msg["filename"], msg["bytes"], "image") await write_chat_file(msg, "image")
if self.active_socket.get(msg["senderId"]):
await self.active_socket["senderId"].send_json(
{
"event": "chat-image-send-ok",
"chatType": 1,
"receiverIds": receiver_ids,
"currentChunkNum": msg["currentChunkNum"],
"totalChunkNum": msg["totalChunkNum"],
"filename": msg["filename"],
}
)