Verified Commit f960652c authored by Andrey Vertiprahov's avatar Andrey Vertiprahov
Browse files

Fix Pydantic models for MRT Service request.

parent 1c8b40ee
......@@ -6,7 +6,7 @@
# ----------------------------------------------------------------------
# Python modules
from typing import List
from typing import List, Optional, Literal
# Third-party modules
from pydantic import BaseModel
......@@ -26,13 +26,28 @@ from pydantic import BaseModel
# ]
class MRTArgs(BaseModel):
class MRTCommandsArgs(BaseModel):
commands: List[str]
include_commands: str
ignore_cli_errors: str
include_commands: bool = False
ignore_cli_errors: bool = False
class MRTScript(BaseModel):
id: str
script: str
args: MRTArgs
class MRTInterfaceArgs(BaseModel):
interface: str
class MRTInterfaceScript(BaseModel):
id: int
script: Literal["get_mac_address_table"]
args: Optional[MRTInterfaceArgs]
class MRTAnyScript(BaseModel):
id: int
script: Literal["get_version"]
class MRTCommandScript(BaseModel):
id: int
script: Literal["commands"]
args: MRTCommandsArgs
......@@ -9,7 +9,7 @@
# Python modules
import logging
import asyncio
from typing import List
from typing import List, Union
# Third-party modules
import orjson
......@@ -19,7 +19,7 @@ from fastapi.responses import StreamingResponse
# NOC modules
from noc.aaa.models.user import User
from noc.core.service.loader import get_service
from noc.services.mrt.models.mrt import MRTScript
from noc.services.mrt.models.mrt import MRTInterfaceScript, MRTCommandScript, MRTAnyScript
from noc.core.service.deps.user import get_current_user
from noc.core.service.error import RPCRemoteError, RPCError
......@@ -81,11 +81,13 @@ async def _run_script(current_user, oid, script, args, span_id=0, bi_id=None):
return {"id": str(oid), "result": r}
async def _iterdata(req, current_user):
async def _iterdata(
req: List[Union[MRTCommandScript, MRTInterfaceScript, MRTAnyScript]], current_user
):
service = get_service()
metrics["mrt_requests"] += 1
# Object ids
ids = set(int(d.id) for d in req if hasattr(d, "id") and hasattr(d, "script"))
ids = set(int(d.id) for d in req)
logger.info(
"Run task on parralels: %d (Max concurrent %d), for User: %s",
len(req),
......@@ -142,7 +144,10 @@ async def _iterdata(req, current_user):
@router.post("/api/mrt/")
async def api_mrt(req: List[MRTScript], current_user: User = Depends(get_current_user)):
async def api_mrt(
req: List[Union[MRTCommandScript, MRTInterfaceScript, MRTAnyScript]],
current_user: User = Depends(get_current_user),
):
# Disable nginx proxy buffering
headers = {"X-Accel-Buffering": "no"}
return StreamingResponse(_iterdata(req, current_user), media_type="text/html", headers=headers)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment