From 4f95b58871a22dc7c1dc78ab4bbd4b4d30ef8c72 Mon Sep 17 00:00:00 2001 From: htylight Date: Wed, 23 Aug 2023 16:32:26 +0800 Subject: [PATCH] create group chat and send group chat message v1 --- .../4947792c7572_create_group_chat_table.py | 40 +++++++++++++ src/crud/contact_crud.py | 20 +++++-- src/crud/group_chat_crud.py | 60 +++++++++++++++++++ src/crud/user_crud.py | 3 +- src/database/db.py | 3 +- src/database/json_typeddict.py | 2 +- src/database/models.py | 40 +++++++++++++ src/main.py | 46 +++++++++++--- src/response_models/group_chat_response.py | 29 +++++++++ src/routers/contact.py | 19 ++++-- src/routers/group_chat.py | 52 ++++++++++++++++ src/routers/message.py | 10 ---- src/utils/email_code.py | 32 +++++----- src/utils/web_socket.py | 12 +++- 14 files changed, 319 insertions(+), 49 deletions(-) create mode 100644 migrations/versions/4947792c7572_create_group_chat_table.py create mode 100644 src/crud/group_chat_crud.py create mode 100644 src/response_models/group_chat_response.py create mode 100644 src/routers/group_chat.py delete mode 100644 src/routers/message.py diff --git a/migrations/versions/4947792c7572_create_group_chat_table.py b/migrations/versions/4947792c7572_create_group_chat_table.py new file mode 100644 index 0000000..94020a8 --- /dev/null +++ b/migrations/versions/4947792c7572_create_group_chat_table.py @@ -0,0 +1,40 @@ +"""create group_chat table + +Revision ID: 4947792c7572 +Revises: 86195fe53b88 +Create Date: 2023-08-22 11:20:25.259711 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '4947792c7572' +down_revision = '86195fe53b88' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('group_chat', + sa.Column('id', sa.String(length=11), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('supervisor', sa.String(length=26), nullable=False), + sa.Column('administrators', sa.ARRAY(sa.String(length=26)), nullable=False), + sa.Column('members', sa.ARRAY(sa.String(length=26)), nullable=False), + sa.Column('introduction', sa.String(length=100), nullable=False), + sa.Column('tags', sa.ARRAY(sa.String()), nullable=False), + sa.Column('noticeboard', postgresql.JSONB(astext_type=sa.Text()), nullable=False), + sa.Column('avatar', sa.String(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('group_chat') + # ### end Alembic commands ### diff --git a/src/crud/contact_crud.py b/src/crud/contact_crud.py index 6762087..1bdf7d0 100644 --- a/src/crud/contact_crud.py +++ b/src/crud/contact_crud.py @@ -5,7 +5,7 @@ from sqlalchemy import ScalarResult, Result from sqlalchemy.orm.attributes import flag_modified from ..database.db import async_session -from ..database.models import Contact, Apply, UserAccount, UserProfile +from ..database.models import Contact, Apply, UserAccount, UserProfile, GroupChat async def insert_contact_friend( @@ -57,15 +57,23 @@ async def select_contact_all(user_id: str) -> Contact: async def select_friends_group_chats( friend_ids: list[str], -) -> list[Tuple[UserAccount, UserProfile]]: + group_chat_ids: list[str], +) -> Tuple[list[Tuple[UserAccount, UserProfile]], list[GroupChat]]: session = async_session() - res: Result[list[Tuple[UserAccount, UserProfile]]] = await session.execute( + friends_res: Result[list[Tuple[UserAccount, UserProfile]]] = await session.execute( select(UserAccount, UserProfile) .join(UserAccount.profile) .where(UserAccount.id.in_(friend_ids)) ) - - return res.all() + if group_chat_ids: + group_chats_res = await session.scalars( + select(GroupChat).where(GroupChat.id.in_(group_chat_ids)) + ) + await session.close() + return friends_res.all(), group_chats_res.all() + else: + await session.close() + return friends_res.all(), [] async def update_friend_setting( @@ -106,7 +114,7 @@ async def update_groups( if pair[1] == "" and deleted_group == "": continue for friend_id, friend_setting in contact.friends.items(): - if pair[0] == friend_setting["friendGroup"]: + if pair[0] == friend_setting["friendGroup"] and pair[1] != "": contact.friends[friend_id]["friendGroup"] = pair[1] if friend_setting["friendGroup"] == deleted_group: contact.friends[friend_id]["friendGroup"] = default_group diff --git a/src/crud/group_chat_crud.py b/src/crud/group_chat_crud.py new file mode 100644 index 0000000..f9e71d3 --- /dev/null +++ b/src/crud/group_chat_crud.py @@ -0,0 +1,60 @@ +import random + +from sqlalchemy import select, insert, ScalarResult +from sqlalchemy.orm.attributes import flag_modified + +from ..database.db import async_session +from ..database.models import Contact, GroupChat, UserProfile + +pool = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] + + +def _create_random_id() -> str: + random_chars = random.choices(pool, k=11) + return "".join(random_chars) + + +async def insert_group_chat(supervisor: str, members: list[str]) -> GroupChat: + id = _create_random_id() + name = f"群聊 ({id})" + session = async_session() + try: + group_chat_res = await session.scalars( + insert(GroupChat) + .values(id=id, name=name, supervisor=supervisor, members=members) + .returning(GroupChat) + ) + contact_res: ScalarResult[Contact] = await session.scalars( + select(Contact).where(Contact.user_id.in_(members)) + ) + for contact in contact_res.all(): + contact.group_chats[id] = {"nameRemark": "", "myRemark": ""} + flag_modified(contact, "group_chats") + + session.add_all(contact_res.all()) + await session.commit() + await session.close() + return group_chat_res.one() + except Exception as e: + await session.close() + raise e + + +async def select_member_name_avatar(member_id: str, is_friend: bool): + session = async_session() + if is_friend: + res: ScalarResult[Contact] = await session.scalars( + select(Contact.group_chats).where(Contact.user_id == member_id) + ) + await session.close() + # {'81906574618': {'myRemark': '', 'nameRemark': ''}} + return res.one() + else: + res = await session.execute( + select(UserProfile.nickname, UserProfile.avatar, Contact.group_chats) + .join(Contact, UserProfile.user_id == Contact.user_id) + .where(UserProfile.user_id == member_id) + ) + await session.close() + # ('htylight', 'cznowoyn1692502503.png', {'81906574618': {'myRemark': '', 'nameRemark': ''}}) + return res.one() diff --git a/src/crud/user_crud.py b/src/crud/user_crud.py index 559ebe5..f7f9fdf 100755 --- a/src/crud/user_crud.py +++ b/src/crud/user_crud.py @@ -47,8 +47,9 @@ async def select_account_by( select(UserAccount).where(UserAccount.id == value) ) - user = res.first() + await session.close() + user = res.first() return (True, user) if user else (False, None) diff --git a/src/database/db.py b/src/database/db.py index fc2a3dd..fd0c41b 100755 --- a/src/database/db.py +++ b/src/database/db.py @@ -1,6 +1,5 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine db_url = "postgresql+asyncpg://together:togetherno.1@localhost/together" -engine = create_async_engine(db_url, echo=True) +engine = create_async_engine(db_url) async_session = async_sessionmaker(engine, expire_on_commit=False) - diff --git a/src/database/json_typeddict.py b/src/database/json_typeddict.py index e323599..ce9af44 100644 --- a/src/database/json_typeddict.py +++ b/src/database/json_typeddict.py @@ -7,5 +7,5 @@ class FriendSetting(TypedDict): class GroupChatSetting(TypedDict): - groupChatRemark: str | None + nameRemark: str | None myRemark: str | None diff --git a/src/database/models.py b/src/database/models.py index 504b439..d8f87d8 100755 --- a/src/database/models.py +++ b/src/database/models.py @@ -134,3 +134,43 @@ class Apply(Base): "setting": self.setting, "createdAt": self.created_at.strftime("%y-%m-%d %H:%M:%S"), } + + +class GroupChat(Base): + __tablename__ = "group_chat" + id: Mapped[str] = mapped_column(String(11), primary_key=True) + name: Mapped[str] = mapped_column(String) + supervisor: Mapped[str] = mapped_column(String(26)) + administrators: Mapped[list[str]] = mapped_column(ARRAY(String(26)), default=[]) + members: Mapped[list[str]] = mapped_column(ARRAY(String(26))) + introduction: Mapped[str] = mapped_column(String(100), default="") + tags: Mapped[list[str]] = mapped_column(ARRAY(String), default=[]) + noticeboard: Mapped[list[dict]] = mapped_column(JSONB, default=[]) + avatar: Mapped[str] = mapped_column(String, default="") + created_at: Mapped[datetime] = mapped_column( + default=datetime.now, onupdate=datetime.now + ) + + def to_dict(self): + return { + "id": self.id, + "name": self.name, + "supervisor": self.supervisor, + "administrators": self.administrators, + "members": self.members, + "introduction": self.introduction, + "noticeboard": self.noticeboard, + "tags": self.tags, + "avatar": self.avatar, + } + + +# class UnreceivedMsg(Base): +# __tablename__ = "unreceived_msg" +# id: Mapped[int] = mapped_column(Integer, primary_key=True) +# receiver_id: Mapped[str] = mapped_column(String(26)) +# sender_id: Mapped[str] = mapped_column(String(26)) +# group_chat_id: Mapped[str] = mapped_column(String, nullable=True) +# type: Mapped[str] = mapped_column(String) +# text: Mapped[str] = mapped_column(String) +# attachments: Mapped[list[str]] = mapped_column(ARRAY(String)) diff --git a/src/main.py b/src/main.py index 432a322..8567db1 100755 --- a/src/main.py +++ b/src/main.py @@ -1,9 +1,9 @@ -from fastapi import FastAPI, Depends, WebSocket +from fastapi import FastAPI, Depends, WebSocket, WebSocketDisconnect from fastapi.staticfiles import StaticFiles from .dependencies import verify_token from .utils.email_code import smtp -from .utils.web_socket import manager +from .utils.web_socket import WebSocketManager from .routers.signin import router as signin_router from .routers.signup import router as signup_router from .routers.user_profile import router as user_profile_router @@ -11,7 +11,7 @@ from .routers.user_account import router as user_account_router from .routers.search import router as search_router from .routers.apply import router as apply_router from .routers.contact import router as contact_router -from .routers.message import router as message_router +from .routers.group_chat import router as group_chat_router app = FastAPI() @@ -22,9 +22,7 @@ app.include_router(user_account_router, dependencies=[Depends(verify_token)]) app.include_router(search_router, dependencies=[Depends(verify_token)]) app.include_router(apply_router, dependencies=[Depends(verify_token)]) app.include_router(contact_router, dependencies=[Depends(verify_token)]) -app.include_router( - message_router, -) +app.include_router(group_chat_router) app.mount("/static", StaticFiles(directory="static"), name="static") @@ -39,6 +37,36 @@ async def main(): return {"code": 10000, "msg": "hello world"} -@app.websocket("/ws/connect") -async def connect_websocket(websocket: WebSocket, id: str): - await manager.connect(id, websocket) +ws_manager = WebSocketManager() + + +@app.websocket("/ws/{user_id}") +async def connect_websocket(websocket: WebSocket, user_id: str): + await ws_manager.connect(user_id, websocket) + + try: + while True: + data = await websocket.receive_json() + match data["event"]: + case "ping": + await ws_manager.active_socket[user_id].send_json({"type": "pong"}) + case "friend-chat-msg": + if ws_manager.active_socket.get(data["receiverId"]): + await ws_manager.send_to_another(data["receiverId"], data) + case "apply-friend": + if ws_manager.active_socket.get(data["recipient"]): + await ws_manager.send_to_another(data["recipient"], data) + case "friend-added" | "friend-deleted": + if ws_manager.active_socket.get(data["receiverId"]): + await ws_manager.send_to_another(data["receiverId"], data) + case "friend-chat-image": + if ws_manager.active_socket.get(data["receiverId"]): + await ws_manager.send_to_another(data["receiverId"], data) + case "group-chat-creation" | "group-chat-msg" | "group-chat-image": + receiver_ids = data["receiverIds"] + del data["receiverIds"] + await ws_manager.broadcast(data, receiver_ids) + + except WebSocketDisconnect: + print(f"{user_id} disconnect") + ws_manager.disconnect(user_id) diff --git a/src/response_models/group_chat_response.py b/src/response_models/group_chat_response.py new file mode 100644 index 0000000..e2de176 --- /dev/null +++ b/src/response_models/group_chat_response.py @@ -0,0 +1,29 @@ +from pydantic import BaseModel + +from .base import BaseResponseModel + + +class _GroupChatProfile(BaseModel): + id: str + name: str + supervisor: str + administrators: list[str] + members: list[str] + introduction: str + noticeboard: list[dict] + tags: list[str] + avatar: str + + +class _MemberNameAvatar(BaseModel): + remark: str + nickname: str + avatar: str + + +class GroupChatProfileResponse(BaseResponseModel): + data: _GroupChatProfile | None = None + + +class MemberNameAvatarResponse(BaseResponseModel): + data: _MemberNameAvatar diff --git a/src/routers/contact.py b/src/routers/contact.py index c0e4dfa..f47379d 100755 --- a/src/routers/contact.py +++ b/src/routers/contact.py @@ -2,6 +2,7 @@ from fastapi import APIRouter from pydantic import BaseModel from ..crud import contact_crud, user_crud +from ..database.models import UserAccount, GroupChat from ..response_models.contact_response import ( BaseResponseModel, ContactResponse, @@ -39,24 +40,32 @@ class MyselfFriendId(BaseModel): @router.get("", response_model=ContactResponse) async def get_contact(id: str): contact = await contact_crud.select_contact_all(id) - print(contact.to_dict()) return {"code": 10700, "msg": "Get Contact Successfully", "data": contact.to_dict()} @router.post("/profiles", response_model=ContactAccountProfileResponse) async def get_contact_account_profiles(contact_ids: ContactIds): - res = await contact_crud.select_friends_group_chats(contact_ids.friend_ids) - + friends_res, group_chats_res = await contact_crud.select_friends_group_chats( + **contact_ids.model_dump() + ) friends_account_profiles = {} + group_chats_profiles = {} - for account, profile in res: + for account, profile in friends_res: friends_account_profiles[account.id] = account.to_dict() friends_account_profiles[account.id].update(profile.to_dict()) + if group_chats_res: + for group_chat in group_chats_res: + group_chats_profiles[group_chat.id] = group_chat.to_dict() + return { "code": 10700, "msg": "Get Contact Profiles Successfully", - "data": {"friends": friends_account_profiles}, + "data": { + "friends": friends_account_profiles, + "groupChats": group_chats_profiles, + }, } diff --git a/src/routers/group_chat.py b/src/routers/group_chat.py new file mode 100644 index 0000000..9cb0db8 --- /dev/null +++ b/src/routers/group_chat.py @@ -0,0 +1,52 @@ +from fastapi import APIRouter +from pydantic import BaseModel + +from ..crud import group_chat_crud +from ..response_models.group_chat_response import ( + GroupChatProfileResponse, + MemberNameAvatarResponse, +) + +router = APIRouter(prefix="/group_chat", tags=["group_chat"]) + + +class GroupChatCreate(BaseModel): + supervisor: str + members: list[str] + + +@router.post("/create", response_model=GroupChatProfileResponse) +async def create_group_chat(group_chat_create: GroupChatCreate): + try: + group_chat = await group_chat_crud.insert_group_chat( + **group_chat_create.model_dump() + ) + return { + "code": 10800, + "msg": "Create Group Chat Successfully", + "data": group_chat.to_dict(), + } + + except Exception as e: + print(f"Creating Group Chat fail with error: {e}") + return {"code": 9999, "msg": "Server error"} + + +@router.get("/member_name_avatar", response_model=MemberNameAvatarResponse) +async def get_member_name_avatar(group_chat_id: str, member_id: str, is_friend: bool): + res = await group_chat_crud.select_member_name_avatar(member_id, is_friend) + data = {} + if is_friend: + data["remark"] = res[group_chat_id]["myRemark"] + data["nickname"] = "" + data["avatar"] = "" + else: + data["remark"] = res[2][group_chat_id]["myRemark"] + data["nickname"] = res[0] + data["avatar"] = res[1] + + return { + "code": 10800, + "msg": "Get Group Chat Member Name and Avatar Successfully", + "data": data, + } diff --git a/src/routers/message.py b/src/routers/message.py deleted file mode 100644 index 3207ef5..0000000 --- a/src/routers/message.py +++ /dev/null @@ -1,10 +0,0 @@ -from fastapi import APIRouter, WebSocket - -from ..utils.web_socket import WebSocketManager - -router = APIRouter(prefix="/message", tags=["message"]) - - -@router.websocket("/friend") -async def send_message_to_friend(websocket: WebSocket): - pass diff --git a/src/utils/email_code.py b/src/utils/email_code.py index 8351642..5fba145 100755 --- a/src/utils/email_code.py +++ b/src/utils/email_code.py @@ -4,31 +4,35 @@ from smtplib import SMTP, SMTPServerDisconnected from src.database.redis_api import redis_server -smtp = SMTP(host='smtp.office365.com') +host = "smtp.163.com" +username = "together_app@163.com" +password = "AGBEYAOHPTAHHMKK" + +smtp = SMTP(host=host) smtp.ehlo() smtp.starttls() -smtp.login('together_app@outlook.com', 'togetherno.1') +smtp.login(username, password) def connect_email_server(): try: smtp.noop() except SMTPServerDisconnected: - smtp.connect(host='smtp.office365.com') + smtp.connect(host=host) smtp.ehlo() smtp.starttls() - smtp.login('together_app@outlook.com', 'togetherno.1') + smtp.login(username, password) def send_email(to: str): code = generate_code(to) msg = EmailMessage() connect_email_server() - msg['Subject'] = 'Together app signup verification code' - msg['From'] = 'TogetherApp ' - msg['To'] = f'<{to}>' + msg["Subject"] = "Together app signup verification code" + msg["From"] = "TogetherApp " + msg["To"] = f"<{to}>" - email_content = f'''\ + email_content = f"""\ @@ -37,27 +41,27 @@ def send_email(to: str): - ''' + """ msg.set_content(email_content) smtp.send_message(msg) def generate_code(email: str) -> str: - seed = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] + seed = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] chosen_elements = random.choices(seed, k=6) - code = ''.join(chosen_elements) - redis_server.set(f'code:{email}', code, ex=60) + code = "".join(chosen_elements) + redis_server.set(f"code:{email}", code, ex=60) return code def verify_code(email: str, code: str) -> bool: - key = f'code:{email}' + key = f"code:{email}" value = redis_server.get(key) return code == value def has_code(email: str) -> bool: - key = f'code:{email}' + key = f"code:{email}" value = redis_server.get(key) return True if value else False diff --git a/src/utils/web_socket.py b/src/utils/web_socket.py index f802ee4..bc7cc22 100644 --- a/src/utils/web_socket.py +++ b/src/utils/web_socket.py @@ -9,5 +9,15 @@ class WebSocketManager: await websocket.accept() self.active_socket[id] = websocket + def disconnect(self, user_id: str): + del self.active_socket[user_id] -manager = WebSocketManager() + async def send_to_another(self, another: str, msg: dict): + socket = self.active_socket.get(another) + await socket.send_json(msg) + + async def broadcast(self, msg: dict, receiver_ids: list[str]): + for receiver_id in receiver_ids: + socket = self.active_socket.get(receiver_id) + if socket: + await socket.send_json(msg)