implement slice transmitting of chat image

main
htylight 2023-10-01 18:43:25 +08:00
parent d06c061e02
commit 1c00eae62e
9 changed files with 258 additions and 59 deletions

View File

@ -1,8 +1,8 @@
"""create unreceived_msg table
"""create unrecieved_msg table
Revision ID: ef4cbdcc711b
Revision ID: 5ac889c1715f
Revises: 4947792c7572
Create Date: 2023-09-16 10:47:50.809077
Create Date: 2023-09-23 19:48:51.530634
"""
from alembic import op
@ -10,7 +10,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'ef4cbdcc711b'
revision = '5ac889c1715f'
down_revision = '4947792c7572'
branch_labels = None
depends_on = None
@ -19,19 +19,21 @@ depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('unreceived_msg',
sa.Column('msg_id', sa.String(length=16), nullable=False),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('msg_id', sa.String(length=18), nullable=False),
sa.Column('event', sa.String(), nullable=False),
sa.Column('type', sa.String(), nullable=False),
sa.Column('receiver_id', sa.String(length=26), nullable=False),
sa.Column('sender_id', sa.String(length=26), nullable=False),
sa.Column('group_chat_id', sa.String(length=11), nullable=True),
sa.Column('name', sa.String(), nullable=True),
sa.Column('remark_in_group_chat', sa.String(), nullable=True),
sa.Column('nickname', sa.String(), nullable=True),
sa.Column('avatar', sa.String(), nullable=True),
sa.Column('text', sa.String(), nullable=False),
sa.Column('attachments', sa.ARRAY(sa.String()), nullable=False),
sa.Column('date_time', sa.String(), nullable=False),
sa.Column('is_show_time', sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint('msg_id')
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###

View File

@ -29,7 +29,7 @@ async def insert_group_chat(supervisor: str, members: list[str]) -> GroupChat:
select(Contact).where(Contact.user_id.in_(members))
)
for contact in contact_res.all():
contact.group_chats[id] = {"nameRemark": "", "myRemark": ""}
contact.group_chats[id] = {"groupChatRemark": "", "remarkInGroupChat": ""}
flag_modified(contact, "group_chats")
session.add_all(contact_res.all())
@ -57,7 +57,10 @@ async def insert_group_chat_members(group_chat_id: str, members: list[str]):
contact: Contact = (
await session.scalars(select(Contact).where(Contact.user_id == member))
).one()
contact.group_chats[group_chat_id] = {"myRemark": "", "nameRemark": ""}
contact.group_chats[group_chat_id] = {
"remarkInGroupChat": "",
"groupChatRemark": "",
}
flag_modified(contact, "group_chats")
session.add(contact)
@ -75,7 +78,7 @@ async def select_member_name_avatar(member_id: str, is_friend: bool) -> dict | T
select(Contact.group_chats).where(Contact.user_id == member_id)
)
await session.close()
# {'81906574618': {'myRemark': '', 'nameRemark': ''}}
# {'81906574618': {'remarkInGroupChat': '', 'groupChatRemark': ''}}
return res.one()
else:
res = await session.execute(
@ -84,7 +87,7 @@ async def select_member_name_avatar(member_id: str, is_friend: bool) -> dict | T
.where(UserProfile.user_id == member_id)
)
await session.close()
# ('htylight', 'cznowoyn1692502503.png', {'81906574618': {'myRemark': '', 'nameRemark': ''},})
# ('htylight', 'cznowoyn1692502503.png', {'81906574618': {'remarkInGroupChat': '', 'groupChatRemark': ''},})
return res.one()
@ -99,7 +102,7 @@ async def select_full_profile(group_chat_id: str) -> Tuple[GroupChat, list[Tuple
UserProfile.user_id,
UserProfile.nickname,
UserProfile.avatar,
Contact.group_chats[group_chat_id]["myRemark"],
Contact.group_chats[group_chat_id]["remarkInGroupChat"],
)
.join(Contact, UserProfile.user_id == Contact.user_id)
.where(UserProfile.user_id.in_(members))
@ -147,7 +150,7 @@ async def update_group_remark(user_id: str, group_chat_id: str, new_remark: str)
contact: Contact = (
await session.scalars(select(Contact).where(Contact.user_id == user_id))
).one()
contact.group_chats[group_chat_id]["nameRemark"] = new_remark
contact.group_chats[group_chat_id]["groupChatRemark"] = new_remark
flag_modified(contact, "group_chats")
session.add(contact)
await session.commit()
@ -163,7 +166,7 @@ async def update_my_remark(user_id: str, group_chat_id: str, new_my_remark: str)
contact: Contact = (
await session.scalars(select(Contact).where(Contact.user_id == user_id))
).one()
contact.group_chats[group_chat_id]["myRemark"] = new_my_remark
contact.group_chats[group_chat_id]["remarkInGroupChat"] = new_my_remark
flag_modified(contact, "group_chats")
session.add(contact)
await session.commit()

View File

@ -17,7 +17,8 @@ async def insert_unreceived_msg(
date_time: str,
is_show_time,
group_chat_id: str = None,
name: str = None,
nickname: str = None,
remark_in_group_chat: str = None,
avatar: str = None,
):
session = async_session()
@ -30,7 +31,8 @@ async def insert_unreceived_msg(
sender_id=sender_id,
receiver_id=receiver_id,
group_chat_id=group_chat_id,
name=name,
nickname=nickname,
remark_in_group_chat=remark_in_group_chat,
avatar=avatar,
text=text,
attachments=attachments,

View File

@ -7,5 +7,5 @@ class FriendSetting(TypedDict):
class GroupChatSetting(TypedDict):
nameRemark: str | None
myRemark: str | None
groupChatRemark: str | None
remarkInGroupChat: str | None

View File

@ -168,13 +168,15 @@ class GroupChat(Base):
class UnreceivedMsg(Base):
__tablename__ = "unreceived_msg"
msg_id: Mapped[str] = mapped_column(String(16), primary_key=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
msg_id: Mapped[str] = mapped_column(String(18))
event: Mapped[str] = mapped_column(String)
type: Mapped[str] = mapped_column(String)
receiver_id: Mapped[str] = mapped_column(String(26))
sender_id: Mapped[str] = mapped_column(String(26))
group_chat_id: Mapped[str] = mapped_column(String(11), nullable=True)
name: Mapped[str] = mapped_column(String, nullable=True)
remark_in_group_chat: Mapped[str] = mapped_column(String, nullable=True)
nickname: Mapped[str] = mapped_column(String, nullable=True)
avatar: Mapped[str] = mapped_column(String, nullable=True)
text: Mapped[str] = mapped_column(String)
attachments: Mapped[list[str]] = mapped_column(ARRAY(String))
@ -188,6 +190,9 @@ class UnreceivedMsg(Base):
"type": self.type,
"senderId": self.sender_id,
"groupChatId": self.group_chat_id,
"nickname": self.nickname,
"remarkInGroupChat": self.remark_in_group_chat,
"avatar": self.avatar,
"text": self.text,
"attachments": self.attachments,
"dateTime": self.date_time,

View File

@ -0,0 +1,21 @@
from .base import BaseModel, BaseResponseModel
class _UnreceivedMsg(BaseModel):
msgId: str
event: str
type: str
receiver_id: str
sender_id: str
group_chat_id: str
nickname: str | None
remarkInGroupChat: str | None
avatar: str | None
text: str
attachments: list[str]
date_time: str
is_show_time: bool
class UnreceivedMsgResponse(BaseResponseModel):
data: _UnreceivedMsg | None = None

View File

@ -1,12 +1,15 @@
import threading
import asyncio
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, BackgroundTasks
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, BackgroundTasks, Depends
from pydantic import BaseModel
from ..dependencies import verify_token
from ..utils.web_socket import WebSocketManager
from ..utils.static_file import read_chat_file
from ..utils.static_file import read_chat_file, async_read_chat_file
from ..crud import unreceived_msg_crud
from ..database.models import UnreceivedMsg
from ..response_models.message_response import UnreceivedMsgResponse
router = APIRouter(tags=["message"])
@ -26,14 +29,20 @@ async def push_unsent_messages():
msgs: list[UnreceivedMsg] = unreceived_msg_crud.select_msgs(user_id)
for msg in msgs:
await ws_manager.send_to_another(user_id, msg.to_dict())
if msg.attachments:
for attachment in msg.attachments:
byte_array = read_chat_file(attachment)
for attachment in msg.attachments:
for (
current_chunk_num,
total_chunk_num,
byte_array,
) in read_chat_file(attachment):
await ws_manager.send_to_another(
user_id,
{
"event": "chat-image",
"filename": attachment,
"tempFilename": f"temp/{attachment}-${total_chunk_num}-${current_chunk_num}",
"totalChunkNum": total_chunk_num,
"currentChunkNum": current_chunk_num,
"bytes": byte_array,
},
)
@ -76,11 +85,22 @@ async def connect_websocket(websocket: WebSocket, user_id: str, is_reconnect: bo
try:
while True:
data = await websocket.receive_json()
match data["event"]:
case "ping":
await ws_manager.active_socket[user_id].send_json({"type": "pong"})
case "friend-chat-msg":
await ws_manager.send_to_another(data["receiverId"], data)
if len(data["attachments"]) > 0:
await ws_manager.send_to_another(
data["senderId"],
{
"event": "pull-chat-image",
"chatType": 0,
"attachments": data["attachments"],
"receiverId": data["receiverId"],
},
)
case "apply-friend":
await ws_manager.send_to_another(data["recipient"], data)
case "friend-added" | "friend-deleted":
@ -91,18 +111,36 @@ async def connect_websocket(websocket: WebSocket, user_id: str, is_reconnect: bo
else:
receiver_ids = data["receiverIds"]
del data["receiverIds"]
await ws_manager.broadcast(data, receiver_ids)
case "group-chat-creation" | "group-chat-msg":
await ws_manager.broadcast(receiver_ids, data)
case "group-chat-creation":
receiver_ids = data["receiverIds"]
del data["receiverIds"]
await ws_manager.broadcast(data, receiver_ids)
await ws_manager.broadcast(receiver_ids, data)
case "group-chat-msg":
receiver_ids = data["receiverIds"]
del data["receiverIds"]
await ws_manager.broadcast(receiver_ids, data)
if len(data["attachments"]) > 0:
await ws_manager.send_to_another(
data["senderId"],
{
"event": "pull-chat-image",
"chatType": 1,
"attachments": data["attachments"],
"receiverIds": receiver_ids,
},
)
except WebSocketDisconnect:
print(f"{user_id} disconnect")
ws_manager.disconnect(user_id)
@router.get("/message/unreceived")
@router.get(
"/message/unreceived",
response_model=UnreceivedMsgResponse,
dependencies=[Depends(verify_token)],
)
async def get_unreceived_msgs(
receiver_id: str,
background_tasks: BackgroundTasks,
@ -120,11 +158,12 @@ async def get_unreceived_msgs(
all_msg_attachments.extend(msg.attachments)
json_msgs.append(msg.to_dict())
background_tasks.add_task(
send_image_by_websocket,
receiver_id,
all_msg_attachments,
)
if all_msg_attachments:
background_tasks.add_task(
send_image_by_websocket,
receiver_id,
all_msg_attachments,
)
return {
"code": 10900,
@ -132,13 +171,48 @@ async def get_unreceived_msgs(
"data": json_msgs,
}
except Exception as e:
return {}
print(e)
return {"code": 9999, "msg": "Server Error"}
class UploadAttachment(BaseModel):
event: str
senderId: str
receiverId: str | None = None
receiverIds: list[str] | None = None
filename: str
tempFilename: str
totalChunkNum: int
currentChunkNum: int
bytes: list[int]
@router.post("/message/attachment")
async def upload_attachment(data: UploadAttachment):
data = data.model_dump()
if data.get("receiverId"):
await ws_manager.send_to_another(data["receiverId"], data)
else:
await ws_manager.broadcast(data["receiverIds"], data)
return {"code": 10900, "msg": "Ok"}
async def send_image_by_websocket(receiver_id: str, attachments: list[str]):
for attachment in attachments:
byte_array = await read_chat_file()
await ws_manager.send_to_another(
receiver_id,
{"event": "chat-image", "filename": attachment, "bytes": byte_array},
)
for (
current_chunk_num,
total_chunk_num,
byte_array,
) in await async_read_chat_file(attachment):
await ws_manager.send_to_another(
receiver_id,
{
"event": "chat-image",
"filename": attachment,
"tempFilename": f"temp/{attachment}-${total_chunk_num}-${current_chunk_num}",
"totalChunkNum": total_chunk_num,
"currentChunkNum": current_chunk_num,
"bytes": byte_array,
},
)

View File

@ -1,7 +1,8 @@
import os
import random
import array
from typing import Literal
import math
from typing import Literal, Tuple
from pathlib import Path
from datetime import datetime
from zipfile import ZipFile
@ -37,6 +38,8 @@ alphabet = [
"z",
]
CHUNK_SIZE = 1024 * 1024 * 1
def create_avatar_dir(type: Literal["user", "group_chat"], dir_name: str) -> Path:
if type == "user":
@ -72,25 +75,82 @@ def create_zip_file(filenames: list[str], file_type: Literal["avatars"]) -> Path
async def write_chat_file(
file_path: str,
file_data: list[int],
msg: dict,
file_type: Literal["image", "video"],
):
if (Path(os.getcwd()) / "static" / "chat" / file_type / file_path).exists():
filename = msg["filename"]
total_chunk_num = msg["totalChunkNum"]
temp_filename = msg["tempFilename"]
if (Path(os.getcwd()) / "static" / "chat" / file_type / filename).exists():
return
match file_type:
case "image":
sub_dir = file_path.split("/")[0]
sub_dir = filename.split("/")[0]
chat_image_dir = Path(os.getcwd()) / "static" / "chat" / "images"
if not (chat_image_dir / sub_dir).exists():
os.makedirs(chat_image_dir / sub_dir)
async with aiofiles.open(chat_image_dir / file_path, "wb") as f:
await f.write(bytearray(file_data))
temp_image_dir = Path(os.getcwd()) / "static" / "temp"
if not (temp_image_dir / sub_dir).exists():
os.makedirs(temp_image_dir / sub_dir)
if total_chunk_num == 1:
async with aiofiles.open(chat_image_dir / filename, "wb") as file:
await file.write(bytearray(msg["bytes"]))
return
if not (temp_image_dir.parent / temp_filename).exists():
try:
async with aiofiles.open(
temp_image_dir.parent / temp_filename, "wb"
) as f:
await f.write(bytearray(msg["bytes"]))
except Exception as e:
print(f"write temp file fail with error: {e}")
temp_file_path_exist: list[bool] = [
(temp_image_dir / f"{filename}-{total_chunk_num}-{i}").exists()
for i in range(total_chunk_num)
]
# assemble the file when all chunk is arrived
if all(temp_file_path_exist):
async with aiofiles.open(chat_image_dir / filename, "ab+") as file:
for i in range(total_chunk_num):
async with aiofiles.open(
temp_image_dir / f"{filename}-{total_chunk_num}-{i}", "rb"
) as chunk_file:
bytes_content = await chunk_file.read()
await file.write(bytes_content)
os.remove(temp_image_dir / f"{filename}-{total_chunk_num}-{i}")
case "video":
pass
def read_chat_file(file_path: str) -> list[int]:
def read_chat_file(filename: str) -> Tuple[int, int, list[int]]:
file_suffix: str = filename.split(".")[1]
if file_suffix == "png":
file_type = "images"
else:
file_type = "videos"
p = Path(os.getcwd()) / "static" / "chat" / file_type / filename
file_size = p.stat().st_size
total_chunk_num = math.ceil(file_size / CHUNK_SIZE)
if not p.exists():
yield -1, 0, []
else:
file = open(p, "rb")
for i in range(total_chunk_num):
byte_data = file.read(CHUNK_SIZE)
yield i, total_chunk_num, array.array("B", byte_data).tolist()
file.close()
async def async_read_chat_file(file_path: str) -> Tuple[int, int, list[int]]:
file_suffix: str = file_path.split(".")[1]
if file_suffix == "png":
@ -99,10 +159,13 @@ def read_chat_file(file_path: str) -> list[int]:
file_type = "videos"
p = Path(os.getcwd()) / "static" / "chat" / file_type / file_path
file_size = p.stat().st_size
total_chunk_num = math.ceil(file_size / CHUNK_SIZE)
if not p.exists():
return []
yield -1, 0, []
else:
with open(p, "rb") as f:
byte_data = f.read()
return array.array("B", byte_data).tolist()
async with aiofiles.open(p, "rb") as f:
for i in range(total_chunk_num):
byte_data = await f.read(CHUNK_SIZE)
yield i, total_chunk_num, array.array("B", byte_data).tolist()

View File

@ -30,6 +30,7 @@ class WebSocketManager:
if self.active_socket.get(another):
socket = self.active_socket.get(another)
await socket.send_json(msg)
else:
if msg["event"] == "friend-chat-msg":
await insert_unreceived_msg(
@ -43,10 +44,26 @@ class WebSocketManager:
msg["dateTime"],
msg["isShowTime"],
)
elif msg["event"] == "chat-image":
await write_chat_file(msg["filename"], msg["bytes"], "image")
if msg["event"] == "chat-image":
await write_chat_file(msg, "image")
if self.active_socket.get(msg["senderId"]):
ws = self.active_socket.get(msg["senderId"])
await ws.send_json(
{
"event": "chat-image-send-ok",
"chatType": 0,
"receiverId": another,
"currentChunkNum": msg["currentChunkNum"],
"totalChunkNum": msg["totalChunkNum"],
"filename": msg["filename"],
}
)
async def broadcast(self, msg: dict, receiver_ids: list[str]):
async def broadcast(
self,
receiver_ids: list[str],
msg: dict,
):
for receiver_id in receiver_ids:
socket = self.active_socket.get(receiver_id)
if socket:
@ -62,10 +79,22 @@ class WebSocketManager:
msg["text"],
msg["attachments"],
msg["dateTime"],
msg["is_show_time"],
msg["isShowTime"],
group_chat_id=msg["groupChatId"],
name=msg["name"],
nickname=msg["nickname"],
remark_in_group_chat=msg["remarkInGroupChat"],
avatar=msg["avatar"],
)
elif msg["event"] == "chat-image":
await write_chat_file(msg["filename"], msg["bytes"], "image")
if msg["event"] == "chat-image":
await write_chat_file(msg, "image")
if self.active_socket.get(msg["senderId"]):
await self.active_socket["senderId"].send_json(
{
"event": "chat-image-send-ok",
"chatType": 1,
"receiverIds": receiver_ids,
"currentChunkNum": msg["currentChunkNum"],
"totalChunkNum": msg["totalChunkNum"],
"filename": msg["filename"],
}
)