diff --git a/poetry.lock b/poetry.lock index 426902a..dba666e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,13 +2,13 @@ [[package]] name = "agentcore" -version = "0.1.1" +version = "0.1.2" description = "A core API for agentsea" optional = false python-versions = "<4.0,>=3.10" files = [ - {file = "agentcore-0.1.1-py3-none-any.whl", hash = "sha256:41b21a35d7ec3153a604c392c48149bea64419fc9b39ba1231bbbf8804624f80"}, - {file = "agentcore-0.1.1.tar.gz", hash = "sha256:01fba442d9abdc449c1d6393247c05611dc6db6c4c3f1988069210a21507c73d"}, + {file = "agentcore-0.1.2-py3-none-any.whl", hash = "sha256:ebf3c2db0f2f95cbf5463e475ce8370176c1e0e8a8d2e12e16eb796d8ccb4010"}, + {file = "agentcore-0.1.2.tar.gz", hash = "sha256:c88e507722b2c98f1d9275a68c35094fb7539be8013dd0f8007e89dd2ee21f19"}, ] [package.dependencies] @@ -16,13 +16,13 @@ pydantic = ">=2.8.2,<3.0.0" [[package]] name = "agentdesk" -version = "0.2.114" +version = "0.2.118" description = "A desktop for AI agents" optional = false python-versions = "<4.0,>=3.10" files = [ - {file = "agentdesk-0.2.114-py3-none-any.whl", hash = "sha256:60253df6078a67abdcef1b88923355fe65200d893a47e75a383508ef6356d886"}, - {file = "agentdesk-0.2.114.tar.gz", hash = "sha256:5958d4f4957ef5c90eb0fec512fe5ae7dcc59f983ce9ca98447959a59d9ff3e2"}, + {file = "agentdesk-0.2.118-py3-none-any.whl", hash = "sha256:b48523744a41485b5d4a6273716459e66c8222ba99f9821fe6b9c2f092f41240"}, + {file = "agentdesk-0.2.118.tar.gz", hash = "sha256:ade6aee1c19296a5f424b7c8d779a2e4b737085480bc2eb88c9b8d174b1d5168"}, ] [package.dependencies] @@ -354,17 +354,17 @@ six = "*" [[package]] name = "boto3" -version = "1.36.0" +version = "1.36.7" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.36.0-py3-none-any.whl", hash = "sha256:d0ca7a58ce25701a52232cc8df9d87854824f1f2964b929305722ebc7959d5a9"}, - {file = "boto3-1.36.0.tar.gz", hash = "sha256:159898f51c2997a12541c0e02d6e5a8fe2993ddb307b9478fd9a339f98b57e00"}, + {file = "boto3-1.36.7-py3-none-any.whl", hash = "sha256:ab501f75557863e2d2c9fa731e4fe25c45f35e0d92ea0ee11a4eaa63929d3ede"}, + {file = "boto3-1.36.7.tar.gz", hash = "sha256:ae98634efa7b47ced1b0d7342e2940b32639eee913f33ab406590b8ed55ee94b"}, ] [package.dependencies] -botocore = ">=1.36.0,<1.37.0" +botocore = ">=1.36.7,<1.37.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.11.0,<0.12.0" @@ -373,13 +373,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "boto3-stubs" -version = "1.36.0" -description = "Type annotations for boto3 1.36.0 generated with mypy-boto3-builder 8.8.0" +version = "1.36.7" +description = "Type annotations for boto3 1.36.7 generated with mypy-boto3-builder 8.8.0" optional = false python-versions = ">=3.8" files = [ - {file = "boto3_stubs-1.36.0-py3-none-any.whl", hash = "sha256:87376acbbde3aeeea301aefb1e7725be2aa7b9d7ebcc340a8ce0bdcfa10b8e95"}, - {file = "boto3_stubs-1.36.0.tar.gz", hash = "sha256:2f766cb3a10000c08464f4894cce607f65704619c96971970fd49a681d0875b3"}, + {file = "boto3_stubs-1.36.7-py3-none-any.whl", hash = "sha256:d5d3f1f537c4d317f1f11b1cb4ce8f427822204936e29419b43c709ec54758ea"}, + {file = "boto3_stubs-1.36.7.tar.gz", hash = "sha256:197bdbacd3a9085c6310a06f21616f30f6103ed8be67705962620ac4587ba1fb"}, ] [package.dependencies] @@ -436,7 +436,7 @@ bedrock-data-automation-runtime = ["mypy-boto3-bedrock-data-automation-runtime ( bedrock-runtime = ["mypy-boto3-bedrock-runtime (>=1.36.0,<1.37.0)"] billing = ["mypy-boto3-billing (>=1.36.0,<1.37.0)"] billingconductor = ["mypy-boto3-billingconductor (>=1.36.0,<1.37.0)"] -boto3 = ["boto3 (==1.36.0)"] +boto3 = ["boto3 (==1.36.7)"] braket = ["mypy-boto3-braket (>=1.36.0,<1.37.0)"] budgets = ["mypy-boto3-budgets (>=1.36.0,<1.37.0)"] ce = ["mypy-boto3-ce (>=1.36.0,<1.37.0)"] @@ -799,13 +799,13 @@ xray = ["mypy-boto3-xray (>=1.36.0,<1.37.0)"] [[package]] name = "botocore" -version = "1.36.0" +version = "1.36.7" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.36.0-py3-none-any.whl", hash = "sha256:b54b11f0cfc47fc1243ada0f7f461266c279968487616720fa8ebb02183917d7"}, - {file = "botocore-1.36.0.tar.gz", hash = "sha256:0232029ff9ae3f5b50cdb25cbd257c16f87402b6d31a05bd6483638ee6434c4b"}, + {file = "botocore-1.36.7-py3-none-any.whl", hash = "sha256:a6c6772d777af2957ac9975207fac1ccc4ce101408b85e9b5e3c5ba0bb949102"}, + {file = "botocore-1.36.7.tar.gz", hash = "sha256:9abc64bde5e7d8f814ea91d6fc0a8142511fc96427c19fe9209677c20a0c9e6e"}, ] [package.dependencies] @@ -818,13 +818,13 @@ crt = ["awscrt (==0.23.4)"] [[package]] name = "botocore-stubs" -version = "1.35.99" +version = "1.36.7" description = "Type annotations and code completion for botocore" optional = false python-versions = ">=3.8" files = [ - {file = "botocore_stubs-1.35.99-py3-none-any.whl", hash = "sha256:6fe3fd4140a16ab3dd560050fdae74f41d4ff7ca55f3583df66b866ef6f9a333"}, - {file = "botocore_stubs-1.35.99.tar.gz", hash = "sha256:659626b30b9950c2d3938f938c76bf15a854b759f96617dc2cf41594f7a4e352"}, + {file = "botocore_stubs-1.36.7-py3-none-any.whl", hash = "sha256:77052e3a86a3f77383c638db63379652bafac3a2b310954392e0cfb3dacd3dad"}, + {file = "botocore_stubs-1.36.7.tar.gz", hash = "sha256:51c51da5379d3e4c4cb7e3dbe8451f572ecbfe6a5ced3a76a6b958941ef72409"}, ] [package.dependencies] @@ -1759,13 +1759,13 @@ protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4 [[package]] name = "google-cloud-compute" -version = "1.23.0" +version = "1.24.0" description = "Google Cloud Compute API client library" optional = false python-versions = ">=3.7" files = [ - {file = "google_cloud_compute-1.23.0-py2.py3-none-any.whl", hash = "sha256:785b94b4bb948494d6a0fa77c2b1cfc4092f08840099b5b4d9caddd599365eb8"}, - {file = "google_cloud_compute-1.23.0.tar.gz", hash = "sha256:2cb02563af528f7460f2c8c905230996873e284fffe69ccfe2e47890f9e5490b"}, + {file = "google_cloud_compute-1.24.0-py2.py3-none-any.whl", hash = "sha256:e50b4827b6d7027eaf480f472f4a0d4b097d38a6d3cc10704bc9135dca69e190"}, + {file = "google_cloud_compute-1.24.0.tar.gz", hash = "sha256:3538830f77cdae6e4ac5d9cee1827180c36a2a4574c4ce4a4ffef7562b23c078"}, ] [package.dependencies] @@ -1793,13 +1793,13 @@ protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4 [[package]] name = "google-cloud-container" -version = "2.55.0" +version = "2.55.1" description = "Google Cloud Container API client library" optional = false python-versions = ">=3.7" files = [ - {file = "google_cloud_container-2.55.0-py2.py3-none-any.whl", hash = "sha256:a442fec5cc20f8bc3cfc1d9c292505da782684389d252d8861376bc2e48c7a03"}, - {file = "google_cloud_container-2.55.0.tar.gz", hash = "sha256:6402401e719bde4eb607b1fe153fdb2931cac28f9c59d912f114a48b2e580398"}, + {file = "google_cloud_container-2.55.1-py2.py3-none-any.whl", hash = "sha256:77e640f3c3b25c5e8b7ac5293d8865b7ca7c5bfd77750dff9fb0b59d0bfb3460"}, + {file = "google_cloud_container-2.55.1.tar.gz", hash = "sha256:0f2c351ef6b0f6e01de9e99193ebba4e7f5debba428e9a0b8074919c97cd7555"}, ] [package.dependencies] @@ -2922,13 +2922,13 @@ files = [ [[package]] name = "mypy-boto3-ec2" -version = "1.36.0" -description = "Type annotations for boto3 EC2 1.36.0 service generated with mypy-boto3-builder 8.8.0" +version = "1.36.5" +description = "Type annotations for boto3 EC2 1.36.5 service generated with mypy-boto3-builder 8.8.0" optional = false python-versions = ">=3.8" files = [ - {file = "mypy_boto3_ec2-1.36.0-py3-none-any.whl", hash = "sha256:0d046f20076d93d7e023f8dfb5427daa6dd612f4aaed214697e70c5c4cc141d5"}, - {file = "mypy_boto3_ec2-1.36.0.tar.gz", hash = "sha256:fdd15717474a105d303d0a21d38758f57d31f09a9ac6014bb4e609a1c070b210"}, + {file = "mypy_boto3_ec2-1.36.5-py3-none-any.whl", hash = "sha256:ffa53e0e0fd0b932b9ba6d38f995442a8008b388a70b2790a053be478d7bcf71"}, + {file = "mypy_boto3_ec2-1.36.5.tar.gz", hash = "sha256:04a7ae6fe63e0ca000313b69b3dee8f673151e970f03e405d74b110081f9544f"}, ] [package.dependencies] @@ -4381,20 +4381,20 @@ files = [ [[package]] name = "s3transfer" -version = "0.11.0" +version = "0.11.2" description = "An Amazon S3 Transfer Manager" optional = false python-versions = ">=3.8" files = [ - {file = "s3transfer-0.11.0-py3-none-any.whl", hash = "sha256:f43b03931c198743569bbfb6a328a53f4b2b4ec723cd7c01fab68e3119db3f8b"}, - {file = "s3transfer-0.11.0.tar.gz", hash = "sha256:6563eda054c33bdebef7cbf309488634651c47270d828e594d151cd289fb7cf7"}, + {file = "s3transfer-0.11.2-py3-none-any.whl", hash = "sha256:be6ecb39fadd986ef1701097771f87e4d2f821f27f6071c872143884d2950fbc"}, + {file = "s3transfer-0.11.2.tar.gz", hash = "sha256:3b39185cb72f5acc77db1a58b6e25b977f28d20496b6e58d6813d75f464d632f"}, ] [package.dependencies] -botocore = ">=1.33.2,<2.0a.0" +botocore = ">=1.36.0,<2.0a.0" [package.extras] -crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] +crt = ["botocore[crt] (>=1.36.0,<2.0a.0)"] [[package]] name = "shapely" @@ -5032,24 +5032,24 @@ test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6. [[package]] name = "types-awscrt" -version = "0.23.6" +version = "0.23.8" description = "Type annotations and code completion for awscrt" optional = false python-versions = ">=3.8" files = [ - {file = "types_awscrt-0.23.6-py3-none-any.whl", hash = "sha256:fbf9c221af5607b24bf17f8431217ce8b9a27917139edbc984891eb63fd5a593"}, - {file = "types_awscrt-0.23.6.tar.gz", hash = "sha256:405bce8c281f9e7c6c92a229225cc0bf10d30729a6a601123213389bd524b8b1"}, + {file = "types_awscrt-0.23.8-py3-none-any.whl", hash = "sha256:d66b3817565769f5311b7e171a3c48d3dbf8a8f9c22f02686c2f003b6559a2a5"}, + {file = "types_awscrt-0.23.8.tar.gz", hash = "sha256:2141391a8f4d36cf098406c19d9060b34f13a558c22d4aadac250a0c57d12710"}, ] [[package]] name = "types-s3transfer" -version = "0.10.4" +version = "0.11.2" description = "Type annotations and code completion for s3transfer" optional = false python-versions = ">=3.8" files = [ - {file = "types_s3transfer-0.10.4-py3-none-any.whl", hash = "sha256:22ac1aabc98f9d7f2928eb3fb4d5c02bf7435687f0913345a97dd3b84d0c217d"}, - {file = "types_s3transfer-0.10.4.tar.gz", hash = "sha256:03123477e3064c81efe712bf9d372c7c72f2790711431f9baa59cf96ea607267"}, + {file = "types_s3transfer-0.11.2-py3-none-any.whl", hash = "sha256:09c31cff8c79a433fcf703b840b66d1f694a6c70c410ef52015dd4fe07ee0ae2"}, + {file = "types_s3transfer-0.11.2.tar.gz", hash = "sha256:3ccb8b90b14434af2fb0d6c08500596d93f3a83fb804a2bb843d9bf4f7c2ca60"}, ] [[package]] @@ -5586,4 +5586,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "1e9db2fba23c7658ef8639e97b61fe99813c964f377fbac2224b63801f4bab9e" +content-hash = "78a9985e07f35304d384f20bd0b478e32f8101ef3134df824ba43d24af78ae41" diff --git a/pyproject.toml b/pyproject.toml index 4aa2c2a..ab94f81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "surfkit" -version = "0.1.314" +version = "0.1.316" description = "A toolkit for building AI agents that use devices" authors = ["Patrick Barker ", "Jeffrey Huckabay "] license = "MIT" @@ -19,8 +19,9 @@ devicebay = "^0.1.11" litellm = "^1.35.8" rich = "^13.7.1" tqdm = "^4.66.4" -agentdesk = "^0.2.114" taskara = "^0.1.210" +agentcore = "^0.1.2" +agentdesk = "^0.2.118" [tool.poetry.group.dev.dependencies] diff --git a/surfkit/db/models.py b/surfkit/db/models.py index b06d59f..8ff80f7 100644 --- a/surfkit/db/models.py +++ b/surfkit/db/models.py @@ -9,9 +9,7 @@ class SkillRecord(Base): __tablename__ = "skills" - __table_args__ = ( - Index("idx_skill_owner_id", "owner_id"), - ) + __table_args__ = (Index("idx_skill_owner_id", "owner_id"),) id = Column(String, primary_key=True) owner_id = Column(String, nullable=False) name = Column(String, nullable=False) diff --git a/surfkit/server/routes.py b/surfkit/server/routes.py index d30a97b..d130e28 100644 --- a/surfkit/server/routes.py +++ b/surfkit/server/routes.py @@ -3,6 +3,7 @@ import time from typing import Annotated, Optional, Type +from agentcore.models import V1UserProfile from fastapi import APIRouter, BackgroundTasks, Depends from taskara import Task, TaskStatus from taskara.server.models import V1Task, V1Tasks, V1TaskUpdate @@ -11,7 +12,7 @@ from surfkit.agent import TaskAgent from surfkit.auth.transport import get_user_dependency from surfkit.env import AGENTESEA_HUB_API_KEY_ENV -from surfkit.server.models import V1Agent, V1LearnSkill, V1SolveTask, V1UserProfile +from surfkit.server.models import V1Agent, V1LearnSkill, V1SolveTask from surfkit.skill import Skill DEBUG_ENV_VAR = os.getenv("DEBUG", "false").lower() == "true" diff --git a/surfkit/skill.py b/surfkit/skill.py index a055da2..a1bc40f 100644 --- a/surfkit/skill.py +++ b/surfkit/skill.py @@ -41,7 +41,7 @@ def __init__( status: SkillStatus = SkillStatus.NEEDS_DEFINITION, agent_type: Optional[str] = None, owner_id: Optional[str] = None, - example_tasks: Optional[List[Task]]= None, + example_tasks: Optional[List[Task]] = None, min_demos: Optional[int] = None, demos_outstanding: Optional[int] = None, remote: Optional[str] = None, @@ -183,7 +183,9 @@ def from_record(cls, record: SkillRecord) -> "Skill": db.commit() print(f"updated tasks for skill {record.id}", flush=True) except Exception as e: - print(f"Error updating tasks for skill {record.id}: {e}", flush=True) + print( + f"Error updating tasks for skill {record.id}: {e}", flush=True + ) example_task_ids = json.loads(str(record.example_tasks)) if example_task_ids: @@ -195,7 +197,9 @@ def from_record(cls, record: SkillRecord) -> "Skill": example_task_map = {task.id: task for task in example_tasks} for example_task_id in example_task_ids: if not example_task_map[example_task_id]: - print(f"Example Task {example_task_id} not found, removing from skill") + print( + f"Example Task {example_task_id} not found, removing from skill" + ) continue valid_example_task_ids.append(example_task_id) @@ -206,7 +210,10 @@ def from_record(cls, record: SkillRecord) -> "Skill": db.commit() print(f"updated example_tasks for skill {record.id}", flush=True) except Exception as e: - print(f"Error updating example_tasks for skill {record.id}: {e}", flush=True) + print( + f"Error updating example_tasks for skill {record.id}: {e}", + flush=True, + ) requirements = json.loads(str(record.requirements)) @@ -236,33 +243,54 @@ def save(self): db.commit() @classmethod - def find(cls, remote: Optional[str] = None, **kwargs) -> List["Skill"]: # type: ignore + def find( + cls, remote: Optional[str] = None, owners: Optional[List[str]] = None, **kwargs + ) -> List["Skill"]: print("running find for skills", flush=True) if remote: - resp = requests.get(f"{remote}/v1/skills") - skills = [cls.from_v1(skill) for skill in resp.json()] - for key, value in kwargs.items(): - skills = [ - skill for skill in skills if getattr(skill, key, None) == value - ] - + # Prepare query parameters + params = dict(kwargs) + if owners: + # Pass owners as multiple query parameters + for owner in owners: + params.setdefault("owners", []).append(owner) + + print(f"Query params for remote request: {params}", flush=True) + + try: + resp = requests.get(f"{remote}/v1/skills", params=params) + resp.raise_for_status() + except requests.RequestException as e: + print(f"Error fetching skills from remote: {e}", flush=True) + return [] + + skills_json = resp.json() + skills = [cls.from_v1(V1Skill(**skill_data)) for skill_data in skills_json] + + # Set remote attribute for each skill for skill in skills: skill.remote = remote return skills - for db in cls.get_db(): - records = ( - db.query(SkillRecord) - .filter_by(**kwargs) - .order_by(asc(SkillRecord.created)) - .all() - ) - print(f"skills found in db {records}", flush=True) - return [cls.from_record(record) for record in records] + else: + out = [] + for db in cls.get_db(): + query = db.query(SkillRecord) + + # Apply owners filter if provided + if owners: + query = query.filter(SkillRecord.owner_id.in_(owners)) - raise ValueError("no session") + # Apply additional filters from kwargs + for key, value in kwargs.items(): + query = query.filter(getattr(SkillRecord, key) == value) + + records = query.order_by(asc(SkillRecord.created)).all() + print(f"skills found in db {records}", flush=True) + out.extend([cls.from_record(record) for record in records]) + return out def update(self, data: V1UpdateSkill): print(f"updating skill {self.id} data: {data.model_dump_json()}", flush=True) @@ -279,7 +307,9 @@ def update(self, data: V1UpdateSkill): if data.tasks: self.tasks = [Task.find(id=task_id)[0] for task_id in data.tasks] if data.example_tasks: - self.example_tasks = [Task.find(id=task_id)[0] for task_id in data.example_tasks] + self.example_tasks = [ + Task.find(id=task_id)[0] for task_id in data.example_tasks + ] if data.status: self.status = SkillStatus(data.status) if data.min_demos: @@ -322,7 +352,7 @@ def set_generating_tasks(self, input: bool): def get_task_descriptions(self, limit: Optional[int] = None): maxLimit = len(self.tasks) limit = limit if limit and limit < maxLimit else maxLimit - return { "tasks": [task.description for task in self.tasks[-limit:]]} + return {"tasks": [task.description for task in self.tasks[-limit:]]} def generate_tasks( self, @@ -350,7 +380,9 @@ def generate_tasks( example_str = str( f"Some examples of tasks for this skill are: '{json.dumps(example_task_descriptions)}'" ) - example_schema = str('{"tasks": ' f'{json.dumps(example_task_descriptions)}' '}' ) + example_schema = str( + '{"tasks": ' f"{json.dumps(example_task_descriptions)}" "}" + ) old_task_str = "" old_tasks = self.get_task_descriptions(limit=15000) if old_tasks: @@ -363,15 +395,17 @@ def generate_tasks( f"Generating tasks for skill: '{self.description}', skill ID: {self.id} with requirements: {self.requirements}", flush=True, ) - thread = RoleThread(owner_id=self.owner_id) # TODO is this gonna keep one thread? I don't see a need for a new thread every time + thread = RoleThread( + owner_id=self.owner_id + ) # TODO is this gonna keep one thread? I don't see a need for a new thread every time result: List[Task] = [] - + for n in range(n_permutations): print( f"task generation interation: {n} for skill ID {self.id}", - flush=True - ) - + flush=True, + ) + prompt = ( f"Given the agent skill '{self.description}', and the " f"configurable requirements that the agent skill encompasses '{json.dumps(self.requirements)}', " @@ -380,7 +414,7 @@ def generate_tasks( f"Today's date is {current_date}. " f"{example_str} " f"Please return a raw json object that looks like the following example: " - f'{example_schema} ' + f"{example_schema} " f"{old_task_str}" ) print(f"prompt: {prompt}", flush=True) @@ -402,7 +436,9 @@ def generate_tasks( if not gen_tasks: self.set_generating_tasks(False) raise ValueError(f"no tasks generated for skill ID {self.id}") - gen_tasks = gen_tasks[:1] # take only one as we are doing this one at a time + gen_tasks = gen_tasks[ + :1 + ] # take only one as we are doing this one at a time if not self.owner_id: self.set_generating_tasks(False) @@ -432,7 +468,7 @@ def generate_tasks( flush=True, ) result.append(tsk) - self.save() # need to save for every iteration as we want tasks to incrementally show up + self.save() # need to save for every iteration as we want tasks to incrementally show up self.generating_tasks = False self.save() @@ -447,7 +483,7 @@ def generate_tasks( f"Today's date is {current_date}. " f"{example_str} " f"Please return a raw json object that looks like the following example: " - f'{example_schema} ' + f"{example_schema} " f"{old_task_str} " ) thread = RoleThread(owner_id=self.owner_id) diff --git a/surfkit/types.py b/surfkit/types.py index 894a969..dd42a66 100644 --- a/surfkit/types.py +++ b/surfkit/types.py @@ -7,7 +7,7 @@ import requests import yaml -from sqlalchemy import or_ +from sqlalchemy import and_, or_ from surfkit.config import GlobalConfig @@ -257,26 +257,25 @@ def save(self) -> None: session.commit() @classmethod - def find(cls, remote: Optional[str] = None, **kwargs) -> List["AgentType"]: + def find( + cls, remote: Optional[str] = None, owners: Optional[List[str]] = None, **kwargs + ) -> List["AgentType"]: if remote: - logger.debug( - "finding remote agent_types for: ", remote, kwargs.get("owner_id") - ) + logger.debug("finding remote agent_types for: ", remote, owners) - json_data = {} - if kwargs.get("name"): - json_data["name"] = kwargs.get("name") - if kwargs.get("namespace"): - json_data["namespace"] = kwargs.get("namespace") + # Prepare query parameters + params = dict(kwargs) + if owners: + params["owners"] = owners remote_response = cls._remote_request( remote, "GET", "/v1/agenttypes", - json_data=json_data, + params=params, ) - agent_types = V1AgentTypes(**remote_response) if remote_response is not None: + agent_types = V1AgentTypes(**remote_response) out = [ cls.from_v1(record, kwargs.get("owner_id")) for record in agent_types.types @@ -288,31 +287,44 @@ def find(cls, remote: Optional[str] = None, **kwargs) -> List["AgentType"]: else: return [] else: + out = [] for session in cls.get_db(): - records = session.query(AgentTypeRecord).filter_by(**kwargs).all() - return [cls.from_record(record) for record in records] - - return [] + query = session.query(AgentTypeRecord) + if owners: + query = query.filter(AgentTypeRecord.owner_id.in_(owners)) + for key, value in kwargs.items(): + query = query.filter(getattr(AgentTypeRecord, key) == value) + records = query.all() + out.extend([cls.from_record(record) for record in records]) + return out @classmethod def find_for_user( - cls, user_id: str, name: Optional[str] = None, namespace: Optional[str] = None + cls, user_id: str, owners: Optional[List[str]] = None, **kwargs ) -> List["AgentType"]: for session in cls.get_db(): - # Base query - query = session.query(AgentTypeRecord).filter( - or_( - AgentTypeRecord.owner_id == user_id, # type: ignore - AgentTypeRecord.public == True, - ) - ) + # Base filter: agent types owned by user + user_owned_filter = AgentTypeRecord.owner_id == user_id # type: ignore + + # Public agent types filter + public_filter = AgentTypeRecord.public == True + + # Owners filter if provided + if owners: + owners_filter = AgentTypeRecord.owner_id.in_(owners) + # Combine public and owners filters + public_owners_filter = and_(public_filter, owners_filter) + # Combine user-owned and public owners filters + query_filter = or_(user_owned_filter, public_owners_filter) + else: + # No owners filter, include all public and user-owned agent types + query_filter = or_(user_owned_filter, public_filter) - # Conditionally add name filter if name is provided - if name is not None: - query = query.filter(AgentTypeRecord.name == name) + query = session.query(AgentTypeRecord).filter(query_filter) - if namespace is not None: - query = query.filter(AgentTypeRecord.namespace == namespace) + # Process additional filters from kwargs + for key, value in kwargs.items(): + query = query.filter(getattr(AgentTypeRecord, key) == value) records = query.all() return [cls.from_record(record) for record in records] @@ -447,12 +459,12 @@ def _remote_request( addr: str, method: str, endpoint: str, + params: Optional[dict] = None, json_data: Optional[dict] = None, auth_token: Optional[str] = None, ) -> Any: url = f"{addr}{endpoint}" headers = {} - params = None if not auth_token: auth_token = os.getenv(AGENTESEA_HUB_API_KEY_ENV) @@ -465,13 +477,11 @@ def _remote_request( if auth_token: headers["Authorization"] = f"Bearer {auth_token}" - if method.upper() == "GET" and json_data: - params = json_data - try: if method.upper() == "GET": logger.debug("\ncalling remote task GET with url: ", url) logger.debug("\ncalling remote task GET with headers: ", headers) + logger.debug("\ncalling remote task GET with params: ", params) response = requests.get(url, headers=headers, params=params) elif method.upper() == "POST": logger.debug("\ncalling remote task POST with: ", url) @@ -488,26 +498,14 @@ def _remote_request( else: return None - try: - response.raise_for_status() - except requests.exceptions.HTTPError as e: - logger.debug("HTTP Error:", e) - logger.debug("Status Code:", response.status_code) - try: - logger.debug("Response Body:", response.json()) - except ValueError: - logger.debug("Raw Response:", response.text) - raise + response.raise_for_status() logger.debug("response: ", response.__dict__) logger.debug("response.status_code: ", response.status_code) - try: - response_json = response.json() - logger.debug("response_json: ", response_json) - return response_json - except ValueError: - logger.debug("Raw Response:", response.text) - return None + response_json = response.json() + logger.debug("response_json: ", response_json) + return response_json except requests.RequestException as e: + logger.error("Request Exception:", e) raise e