From 5e104641d2c5fb7af018ddabc183d4713ce00e3d Mon Sep 17 00:00:00 2001 From: htylight Date: Tue, 1 Aug 2023 10:04:41 +0800 Subject: [PATCH] implement friend group, remark and managing group --- ...88_add_default_group_column_of_contact_.py | 28 +++++ src/crud/contact_crud.py | 111 ++++++++++++++++++ src/crud/multitable_crud.py | 43 ------- src/crud/user_crud.py | 6 +- src/database/models.py | 10 ++ src/main.py | 15 +-- src/response_models/apply_response.py | 22 ++++ src/response_models/contact_response.py | 34 ++++++ src/response_models/user_response.py | 6 +- src/routers/apply.py | 18 +-- src/routers/contact.py | 64 +++++++++- src/routers/signin.py | 56 ++++++--- 12 files changed, 332 insertions(+), 81 deletions(-) create mode 100644 migrations/versions/86195fe53b88_add_default_group_column_of_contact_.py create mode 100644 src/crud/contact_crud.py delete mode 100644 src/crud/multitable_crud.py create mode 100644 src/response_models/apply_response.py create mode 100644 src/response_models/contact_response.py diff --git a/migrations/versions/86195fe53b88_add_default_group_column_of_contact_.py b/migrations/versions/86195fe53b88_add_default_group_column_of_contact_.py new file mode 100644 index 0000000..e57c2c7 --- /dev/null +++ b/migrations/versions/86195fe53b88_add_default_group_column_of_contact_.py @@ -0,0 +1,28 @@ +"""add default_group column of contact table + +Revision ID: 86195fe53b88 +Revises: d0c7f4dd4894 +Create Date: 2023-07-31 00:27:37.445472 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '86195fe53b88' +down_revision = 'd0c7f4dd4894' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('contact', sa.Column('default_group', sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('contact', 'default_group') + # ### end Alembic commands ### diff --git a/src/crud/contact_crud.py b/src/crud/contact_crud.py new file mode 100644 index 0000000..3e781ca --- /dev/null +++ b/src/crud/contact_crud.py @@ -0,0 +1,111 @@ +from typing import Tuple + +from sqlalchemy import select, delete, or_ +from sqlalchemy import ScalarResult, Result +from sqlalchemy.orm.attributes import flag_modified + +from ..database.db import async_session +from ..database.models import Contact, Apply, UserAccount, UserProfile + + +async def insert_contact_friend( + relation: int, + applicant: str, + recipient: str, + applicant_setting: dict, + recipient_setting: dict, +): + session = async_session() + try: + await session.execute( + delete(Apply).where( + Apply.recipient == recipient, + Apply.applicant == applicant, + Apply.relation == relation, + ) + ) + res: ScalarResult[Contact] = await session.scalars( + select(Contact).where( + or_(Contact.user_id == applicant, Contact.user_id == recipient) + ) + ) + for contact in res.all(): + if contact.user_id == recipient: + contact.friends[applicant] = recipient_setting + flag_modified(contact, "friends") + else: + contact.friends[recipient] = applicant_setting + flag_modified(contact, "friends") + + session.add_all(res.all()) + await session.commit() + except Exception: + raise Exception + finally: + await session.close() + + +async def select_contact_all(user_id: str) -> Contact: + session = async_session() + res: ScalarResult[Contact] = await session.scalars( + select(Contact).where(Contact.user_id == user_id) + ) + await session.close() + return res.one() + + +async def select_friends_group_chats( + friend_ids: list[str], +) -> list[Tuple[UserAccount, UserProfile]]: + session = async_session() + res: Result[list[Tuple[UserAccount, UserProfile]]] = await session.execute( + select(UserAccount, UserProfile) + .join(UserAccount.profile) + .where(UserAccount.id.in_(friend_ids)) + ) + + return res.all() + + +async def update_friend_setting( + user_id: str, + friend_id: str, + remark: str | None, + group: str | None, +): + session = async_session() + res = await session.scalars(select(Contact).where(Contact.user_id == user_id)) + contact: Contact = res.one() + if remark: + contact.friends[friend_id]["friendRemark"] = remark + else: + contact.friends[friend_id]["friendGroup"] = group + flag_modified(contact, "friends") + session.add(contact) + await session.commit() + await session.close() + + +async def update_groups( + user_id: str, + groups: list[str], + group_name_change_pair: list[list[str]], + default_group: str, +): + session = async_session() + res = await session.scalars(select(Contact).where(Contact.user_id == user_id)) + contact = res.one() + contact.friend_groups = groups + contact.default_group = default_group + + for pair in group_name_change_pair: + if pair[1] == "": + continue + for friend_id, friend_setting in contact.friends.items(): + if pair[0] == friend_setting["friendGroup"]: + contact.friends[friend_id]["friendGroup"] = pair[1] + + flag_modified(contact, "friends") + session.add(contact) + await session.commit() + await session.close() diff --git a/src/crud/multitable_crud.py b/src/crud/multitable_crud.py deleted file mode 100644 index 78b0959..0000000 --- a/src/crud/multitable_crud.py +++ /dev/null @@ -1,43 +0,0 @@ -from sqlalchemy import select, delete, update -from sqlalchemy import Result, ScalarResult, or_ -from sqlalchemy.orm.attributes import flag_modified - -from ..database.db import async_session -from ..database.models import * - - -async def insert_contact_friend( - relation: int, - applicant: str, - recipient: str, - applicant_setting: dict, - recipient_setting: dict, -): - session = async_session() - try: - await session.execute( - delete(Apply).where( - Apply.recipient == recipient, - Apply.applicant == applicant, - Apply.relation == relation, - ) - ) - res: ScalarResult[Contact] = await session.scalars( - select(Contact).where( - or_(Contact.user_id == applicant, Contact.user_id == recipient) - ) - ) - for contact in res.all(): - if contact.user_id == recipient: - contact.friends[applicant] = recipient_setting - flag_modified(contact, "friends") - else: - contact.friends[recipient] = applicant_setting - flag_modified(contact, "friends") - - session.add_all(res.all()) - await session.commit() - except Exception: - raise Exception - finally: - await session.close() diff --git a/src/crud/user_crud.py b/src/crud/user_crud.py index b757917..559ebe5 100755 --- a/src/crud/user_crud.py +++ b/src/crud/user_crud.py @@ -16,7 +16,11 @@ async def insert_user(username: str, password: str, email: str): id = ulid.new().str user = UserAccount(id=id, username=username, password=password, email=email) profile = UserProfile(nickname=username) - contact = Contact(friends={id: {}}, friend_groups=["我的好友"], group_chats={}) + contact = Contact( + friends={id: {"friendGroup": "我的好友", "friendRemark": ""}}, + friend_groups=["我的好友"], + group_chats={}, + ) user.profile = profile user.contact = contact session.add(user) diff --git a/src/database/models.py b/src/database/models.py index ee91096..504b439 100755 --- a/src/database/models.py +++ b/src/database/models.py @@ -81,6 +81,7 @@ class Contact(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) friends: Mapped[dict[str, FriendSetting]] = mapped_column(JSONB, nullable=True) friend_groups: Mapped[list[str]] = mapped_column(ARRAY(String), nullable=True) + default_group: Mapped[str] = mapped_column(String, default="我的好友", nullable=True) group_chats: Mapped[dict[str, GroupChatSetting]] = mapped_column( JSONB, nullable=True ) @@ -95,10 +96,19 @@ class Contact(Base): f"user={self.user_id}, " f"friends={self.friends}, " f"friend_group={self.friend_groups}, " + f"default_group={self.default_group}," f"group_chats={self.group_chats}" f")" ) + def to_dict(self): + return { + "friends": self.friends, + "friendGroups": self.friend_groups, + "defaultGroup": self.default_group, + "groupChats": self.group_chats, + } + class Apply(Base): __tablename__ = "apply" diff --git a/src/main.py b/src/main.py index a0704b9..61e481f 100755 --- a/src/main.py +++ b/src/main.py @@ -1,5 +1,5 @@ from fastapi import FastAPI, Depends -from starlette.responses import FileResponse +from fastapi.staticfiles import StaticFiles from .dependencies import verify_token from .utils.email_code import smtp @@ -9,8 +9,8 @@ from .routers.user_profile import router as user_profile_router from .routers.user_account import router as user_account_router from .routers.search import router as search_router from .routers.apply import router as apply_router +from .routers.contact import router as contact_router -from .utils.static_file import create_zip_file app = FastAPI() app.include_router(signup_router) @@ -19,6 +19,9 @@ app.include_router(user_profile_router, dependencies=[Depends(verify_token)]) app.include_router(user_account_router, dependencies=[Depends(verify_token)]) app.include_router(search_router, dependencies=[Depends(verify_token)]) app.include_router(apply_router, dependencies=[Depends(verify_token)]) +app.include_router(contact_router, dependencies=[Depends(verify_token)]) + +app.mount("/static", StaticFiles(directory="static"), name="static") @app.on_event("shutdown") @@ -29,11 +32,3 @@ def close_smtp(): @app.get("/") async def main(): return {"code": 10000, "msg": "hello world"} - - -@app.get("/zipfile") -async def get_zipfile(): - file = create_zip_file( - ["luhptjjk1688921163.png", "pnjdvldw1688921358.png"], "avatars" - ) - return FileResponse(file) diff --git a/src/response_models/apply_response.py b/src/response_models/apply_response.py new file mode 100644 index 0000000..19ed844 --- /dev/null +++ b/src/response_models/apply_response.py @@ -0,0 +1,22 @@ +from pydantic import BaseModel + +from .base import BaseResponseModel +from .user_response import _UserAccountProfile + + +class _Apply(BaseModel): + relation: int + applicant: str + recipient: str + groupChatId: str | None = None + hello: str + setting: dict + createdAt: str + + +class ApplyListResponse(BaseResponseModel): + data: list[_Apply] | None = None + + +class ApplicantProfilesResponse(BaseResponseModel): + data: dict[str, _UserAccountProfile] diff --git a/src/response_models/contact_response.py b/src/response_models/contact_response.py new file mode 100644 index 0000000..f13d981 --- /dev/null +++ b/src/response_models/contact_response.py @@ -0,0 +1,34 @@ +from pydantic import BaseModel + +from .base import BaseResponseModel +from .user_response import _UserAccountProfile + + +class _FriendSetting(BaseModel): + friendRemark: str | None = None + friendGroup: str | None = None + + +class _GroupChatSetting(BaseModel): + groupChatRemark: str | None = None + myRemark: str | None = None + + +class _ContactResponseData(BaseModel): + friends: dict[str, _FriendSetting] + friendGroups: list[str] + defaultGroup: str + groupChats: dict[str, _GroupChatSetting] + + +class _ContactAccountProfile(BaseModel): + friends: dict[str, _UserAccountProfile] + groupChats: dict | None = None + + +class ContactResponse(BaseResponseModel): + data: _ContactResponseData | None = None + + +class ContactAccountProfileResponse(BaseResponseModel): + data: _ContactAccountProfile | None = None diff --git a/src/response_models/user_response.py b/src/response_models/user_response.py index 6f0ab9f..c4cf513 100755 --- a/src/response_models/user_response.py +++ b/src/response_models/user_response.py @@ -22,7 +22,7 @@ class _UserProfile(BaseModel): class UserAccountResponse(BaseResponseModel): - data: Optional[_UserAccount] = None + data: _UserAccount | None = None class TokenCreationResponse(BaseResponseModel): @@ -30,8 +30,8 @@ class TokenCreationResponse(BaseResponseModel): class TokenSigninResponse(BaseResponseModel): - data: Optional[_UserAccount] = None - token: Optional[str] = None + data: _UserAccount | None = None + token: str | None = None class UserProfileResponse(BaseResponseModel): diff --git a/src/routers/apply.py b/src/routers/apply.py index c4896df..7f8b1b4 100644 --- a/src/routers/apply.py +++ b/src/routers/apply.py @@ -2,8 +2,12 @@ from fastapi import APIRouter, Query from fastapi.responses import FileResponse from pydantic import BaseModel -from src.crud import apply_crud, user_crud, multitable_crud -from ..response_models.base import BaseResponseModel +from src.crud import apply_crud, user_crud, contact_crud +from ..response_models.apply_response import ( + BaseResponseModel, + ApplyListResponse, + ApplicantProfilesResponse, +) from ..utils.static_file import create_zip_file router = APIRouter(prefix="/apply", tags=["apply"]) @@ -39,7 +43,7 @@ async def apply_friend(apply_info: ApplyInfo): return {"code": 10600, "msg": "Apply Friend Successfully"} -@router.get("/list") +@router.get("/list", response_model=ApplyListResponse) async def get_apply_list(recipient: str): res = await apply_crud.select_apply_all(recipient) if not res: @@ -56,7 +60,7 @@ async def get_apply_list(recipient: str): } -@router.get("/applicant_profiles") +@router.get("/applicant_profiles", response_model=ApplicantProfilesResponse) async def get_applicant_profiles(applicant_ids: list[str] = Query(default=None)): res = await user_crud.select_multiuser_info(applicant_ids) applicant_profiles = {} @@ -77,17 +81,17 @@ async def download_applicant_avatars(avatars: list[str] = Query(default=None)): return FileResponse(file_path) -@router.post("/accept") +@router.post("/accept", response_model=BaseResponseModel) async def accept_apply(accept_info: AcceptInfo): try: - await multitable_crud.insert_contact_friend(**accept_info.model_dump()) + await contact_crud.insert_contact_friend(**accept_info.model_dump()) return {"code": 10600, "msg": "Add Friend Successfully"} except Exception as e: print(f"接受添加好友请求出错....: {e}") return {"code": 10601, "msg": "Something Went Wrong On the Server"} -@router.post("/refuse") +@router.post("/refuse", response_model=BaseResponseModel) async def refuse_apply(refuse_info: RefuseInfo): await apply_crud.delete_apply(**refuse_info.model_dump()) return {"code": 10600, "msg": "Refuse Apply Successfully"} diff --git a/src/routers/contact.py b/src/routers/contact.py index d70e383..718ffde 100755 --- a/src/routers/contact.py +++ b/src/routers/contact.py @@ -1,4 +1,66 @@ from fastapi import APIRouter +from pydantic import BaseModel -router = APIRouter(prefix='/profile', tags=['profile']) +from ..crud import contact_crud, user_crud +from ..response_models.contact_response import ( + BaseResponseModel, + ContactResponse, + ContactAccountProfileResponse, +) +router = APIRouter(prefix="/contact", tags=["contact"]) + + +class ContactIds(BaseModel): + friend_ids: list[str] + group_chat_ids: list[str] | None = None + + +class ChangeFriendSetting(BaseModel): + user_id: str + friend_id: str + remark: str | None = None + group: str | None = None + + +class ManageGroups(BaseModel): + user_id: str + groups: list[str] + group_name_change_pair: list[list[str]] + default_group: str + + +@router.get("", response_model=ContactResponse) +async def get_contact(id: str): + contact = await contact_crud.select_contact_all(id) + print(contact.to_dict()) + return {"code": 10700, "msg": "Get Contact Successfully", "data": contact.to_dict()} + + +@router.post("/profiles", response_model=ContactAccountProfileResponse) +async def get_contact_account_profiles(contact_ids: ContactIds): + res = await contact_crud.select_friends_group_chats(contact_ids.friend_ids) + + friends_account_profiles = {} + + for account, profile in res: + friends_account_profiles[account.id] = account.to_dict() + friends_account_profiles[account.id].update(profile.to_dict()) + + return { + "code": 10700, + "msg": "Get Contact Profiles Successfully", + "data": {"friends": friends_account_profiles}, + } + + +@router.post("/change/friend_setting", response_model=BaseResponseModel) +async def change_friend_remark(friend_remark: ChangeFriendSetting): + await contact_crud.update_friend_setting(**friend_remark.model_dump()) + return {"code": 10700, "msg": "change Friend Remark Successfully"} + + +@router.post("/manage_groups", response_model=BaseResponseModel) +async def manage_groups(group_info: ManageGroups): + await contact_crud.update_groups(**group_info.model_dump()) + return {"code": 10700, "msg": "Manage Groups Successfully"} diff --git a/src/routers/signin.py b/src/routers/signin.py index a51bbf2..2f38d83 100755 --- a/src/routers/signin.py +++ b/src/routers/signin.py @@ -7,10 +7,18 @@ from jose import ExpiredSignatureError, JWTError from ..crud.user_crud import select_account_by from ..utils.password import verify_password -from ..utils.token_handler import create_signin_token, oauth2_scheme, verify_signin_token -from ..response_models.user_response import UserAccountResponse, TokenCreationResponse, TokenSigninResponse +from ..utils.token_handler import ( + create_signin_token, + oauth2_scheme, + verify_signin_token, +) +from ..response_models.user_response import ( + UserAccountResponse, + TokenCreationResponse, + TokenSigninResponse, +) -router = APIRouter(prefix='/signin', tags=['signin']) +router = APIRouter(prefix="/signin", tags=["signin"]) class TokenPayload(BaseModel): @@ -18,39 +26,55 @@ class TokenPayload(BaseModel): device_id: str -@router.post('/username', response_model=UserAccountResponse) +@router.post("/username", response_model=UserAccountResponse) async def signin_by_username(form_data: OAuth2PasswordRequestForm = Depends()): username = form_data.username password = form_data.password - is_existence, user = await select_account_by('username', username) + is_existence, user = await select_account_by("username", username) if not is_existence: - return {'code': 10201, 'msg': 'Username or Password Is Incorrect'} + return {"code": 10201, "msg": "Username or Password Is Incorrect"} is_correct = verify_password(password, user.password) if is_correct: - return {'code': 10200, 'msg': 'Sign in Successfully', 'data': user.to_dict()} + return {"code": 10200, "msg": "Sign in Successfully", "data": user.to_dict()} else: - return {'code': 10201, 'msg': 'Username or Password Is Incorrect', 'data': None} + return {"code": 10201, "msg": "Username or Password Is Incorrect"} -@router.post('/token', response_model=TokenCreationResponse) +@router.post("/token", response_model=TokenCreationResponse) async def create_token(token_payload: TokenPayload): token = create_signin_token(**token_payload.model_dump()) - return {'code': 10200, 'msg': 'Create Token Successfully', 'token': token} + return {"code": 10200, "msg": "Create Token Successfully", "token": token} -@router.get('/token', response_model=TokenSigninResponse) +@router.get("/token", response_model=TokenSigninResponse) async def signin_by_token(token: str = Depends(oauth2_scheme)): try: new_token, id = verify_signin_token(token) - _, user = await select_account_by('id', id) + _, user = await select_account_by("id", id) if new_token: - return {'code': 10200, 'msg': 'Sign in Successfully', 'data': user.to_dict(), 'token': new_token} + return { + "code": 10200, + "msg": "Sign in Successfully", + "data": user.to_dict(), + "token": new_token, + } else: - return {'code': 10200, 'msg': 'Sign in Successfully', 'data': user.to_dict(), 'token': token} + return { + "code": 10200, + "msg": "Sign in Successfully", + "data": user.to_dict(), + "token": token, + } except ExpiredSignatureError: - return {'code': 9999, 'msg': 'Token has Expired', 'data': None, 'token': None} + return { + "code": 9999, + "msg": "Token has Expired", + } except JWTError: - return {'code': 9998, 'msg': 'Token Is Not Right', 'data': None, 'token': None} \ No newline at end of file + return { + "code": 9998, + "msg": "Token Is Not Right", + }