diff --git a/backend/open_webui/routers/credit.py b/backend/open_webui/routers/credit.py index 4a9569afa4..defbe9d20b 100644 --- a/backend/open_webui/routers/credit.py +++ b/backend/open_webui/routers/credit.py @@ -14,8 +14,9 @@ CreditLogSimpleModel, CreditLogs, ) +from open_webui.models.models import Models, ModelPriceForm from open_webui.models.users import UserModel -from open_webui.utils.auth import get_current_user +from open_webui.utils.auth import get_current_user, get_admin_user from open_webui.utils.credit.ezfp import ezfp_client log = logging.getLogger(__name__) @@ -89,3 +90,27 @@ async def ticket_callback(request: Request) -> str: @router.get("/callback/redirect", response_class=RedirectResponse) async def ticket_callback_redirect() -> RedirectResponse: return RedirectResponse(url=EZFP_CALLBACK_HOST.value, status_code=302) + + +@router.get("/models/price") +async def get_model_price(_: UserModel = Depends(get_admin_user)): + return { + model.id: model.price if model.price else {} + for model in Models.get_all_models() + if model.id + } + + +@router.put("/models/price") +async def update_model_price( + form_data: dict[str, dict], _: UserModel = Depends(get_admin_user) +): + for model_id, price in form_data.items(): + model = Models.get_model_by_id(id=model_id) + if not model: + continue + model.price = ( + ModelPriceForm.model_validate(price).model_dump() if price else None + ) + Models.update_model_by_id(id=model_id, model=model) + return "success"