
Commit 51fd947

Update pre-commit
1 parent 171a1a7 commit 51fd947

File tree

16 files changed: +115 -121 lines


.pre-commit-config.yaml

Lines changed: 56 additions & 15 deletions
@@ -5,29 +5,70 @@ default_language_version:
 
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v5.0.0
+  rev: v5.0.0 # Latest stable version
   hooks:
-  - id: trailing-whitespace
-  - id: check-ast
   - id: check-merge-conflict
   - id: check-added-large-files
     args: ['--maxkb=1000']
   - id: end-of-file-fixer
     exclude: '^(.*\.svg)$'
 
-- repo: https://github.com/pycqa/flake8
-  rev: 34cbf8ef3950f43d09b85e2e45c15ae5717dc37b
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  rev: v0.9.4
   hooks:
-  - id: flake8
-    files: ^src/llama_stack_client/lib/.* # Only run on files in specific-folder
-    additional_dependencies:
-    - flake8-bugbear == 22.4.25
-    - pep8-naming == 0.12.1
-    - torchfix
-    args: ['--config=.flake8']
+  - id: ruff
+    files: ^src/llama_stack_client/lib/.*
+    args: [
+      --fix,
+      --exit-non-zero-on-fix
+    ]
+  - id: ruff-format
+    files: ^src/llama_stack_client/lib/.*
 
-- repo: https://github.com/pycqa/isort
-  rev: 5.13.2
+- repo: https://github.com/adamchainz/blacken-docs
+  rev: 1.19.0
   hooks:
-  - id: isort
+  - id: blacken-docs
     files: ^src/llama_stack_client/lib/.*
+    additional_dependencies:
+    - black==24.3.0
+
+# - repo: https://github.com/pre-commit/mirrors-mypy
+#   rev: v1.14.0
+#   hooks:
+#   - id: mypy
+#     additional_dependencies:
+#      - types-requests
+#      - types-setuptools
+#      - pydantic
+#     args: [--ignore-missing-imports]
+
+# - repo: https://github.com/jsh9/pydoclint
+#   rev: d88180a8632bb1602a4d81344085cf320f288c5a
+#   hooks:
+#   - id: pydoclint
+#     args: [--config=pyproject.toml]
+
+# - repo: https://github.com/tcort/markdown-link-check
+#   rev: v3.11.2
+#   hooks:
+#   - id: markdown-link-check
+#     args: ['--quiet']
+
+# - repo: local
+#   hooks:
+#   - id: distro-codegen
+#     name: Distribution Template Codegen
+#     additional_dependencies:
+#      - rich
+#      - pydantic
+#     entry: python -m llama_stack.scripts.distro_codegen
+#     language: python
+#     pass_filenames: false
+#     require_serial: true
+#     files: ^llama_stack/templates/.*$
+#     stages: [manual]
+
+ci:
+  autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
+  autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate
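
With the flake8 and isort hooks replaced by ruff and ruff-format, the whole chain can be exercised locally before pushing. A minimal sketch, assuming pre-commit is already installed in the environment (pip install pre-commit):

    pre-commit install               # register the git hook once per clone
    pre-commit run --all-files       # run every configured hook, including ruff and ruff-format
    pre-commit run ruff --all-files  # run only the ruff lint hook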

.flake8 renamed to .ruff.toml

Lines changed: 25 additions & 18 deletions
@@ -1,30 +1,37 @@
-[flake8]
 # Suggested config from pytorch that we can adapt
-select = B,C,E,F,N,P,T4,W,B9,TOR0,TOR1,TOR2
-max-line-length = 120
+lint.select = ["B", "C", "E" , "F" , "N", "W", "B9"]
+
+line-length = 120
+
 # C408 ignored because we like the dict keyword argument syntax
 # E501 is not flexible enough, we're using B950 instead
 # N812 ignored because import torch.nn.functional as F is PyTorch convention
 # N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
 # E731 allow usage of assigning lambda expressions
 # E701 let black auto-format statements on one line
 # E704 let black auto-format statements on one line
-ignore =
-    E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,E701,E704
+lint.ignore = [
+    "E203", "E305", "E402", "E501", "E721", "E741", "F405", "F821", "F841",
+    "C408", "E302", "W291", "E303", "N812", "N817", "E731", "E701",
+    # These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
+    "C901", "C405", "C414", "N803", "N999", "C403", "C416", "B028", "C419", "C401", "B023",
     # shebang has extra meaning in fbcode lints, so I think it's not worth trying
     # to line this up with executable bit
-    EXE001,
+    "EXE001",
     # random naming hints don't need
-    N802,
+    "N802",
     # these ignores are from flake8-bugbear; please fix!
-    B007,B008,B950
-optional-ascii-coding = True
-exclude =
-    ./.git,
-    ./docs/*,
-    ./build,
-    ./venv,
-    *.pyi,
-    .pre-commit-config.yaml,
-    *.md,
-    .flake8
+    "B007", "B008"
+]
+
+exclude = [
+    "./.git",
+    "./docs/*",
+    "./build",
+    "./scripts",
+    "./venv",
+    "*.pyi",
+    ".pre-commit-config.yaml",
+    "*.md",
+    ".flake8"
+]
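
The renamed file carries the old flake8 rule intent into ruff's TOML schema (lint.select, lint.ignore, line-length, exclude). Ruff discovers .ruff.toml at the repository root on its own, so a manual run should behave the same as the hooks above. A rough sketch, assuming ruff ~0.9 is installed locally:

    ruff check --fix src/llama_stack_client/lib   # lint + autofix, mirrors the `ruff` hook
    ruff format src/llama_stack_client/lib        # formatter, mirrors the `ruff-format` hook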

examples/post_training/supervised_fine_tune_client.py

Lines changed: 1 addition & 4 deletions
@@ -30,7 +30,6 @@ async def run_main(
     checkpoint_dir: Optional[str] = None,
     cert_path: Optional[str] = None,
 ):
-
     # Construct the base URL with the appropriate protocol
     protocol = "https" if use_https else "http"
     base_url = f"{protocol}://{host}:{port}"
@@ -102,9 +101,7 @@ def main(
     cert_path: Optional[str] = None,
 ):
     job_uuid = str(job_uuid)
-    asyncio.run(
-        run_main(host, port, job_uuid, model, use_https, checkpoint_dir, cert_path)
-    )
+    asyncio.run(run_main(host, port, job_uuid, model, use_https, checkpoint_dir, cert_path))
 
 
 if __name__ == "__main__":

src/llama_stack_client/_models.py

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ def to_json(
     @override
     def __str__(self) -> str:
         # mypy complains about an invalid self arg
-        return f'{self.__repr_name__()}({self.__repr_str__(", ")})'  # type: ignore[misc]
+        return f"{self.__repr_name__()}({self.__repr_str__(', ')})"  # type: ignore[misc]
 
     # Override the 'construct' method in a way that supports recursive parsing without validation.
     # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.

src/llama_stack_client/lib/agents/agent.py

Lines changed: 5 additions & 6 deletions
@@ -9,15 +9,14 @@
 from llama_stack_client.types import ToolResponseMessage, UserMessage
 from llama_stack_client.types.agent_create_params import AgentConfig
 from llama_stack_client.types.agents.turn import Turn
-from llama_stack_client.types.agents.turn_create_params import (Document,
-                                                                Toolgroup)
-from llama_stack_client.types.agents.turn_create_response import \
-    AgentTurnResponseStreamChunk
+from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
+from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
 
 from .client_tool import ClientTool
 
 DEFAULT_MAX_ITER = 10
 
+
 class Agent:
     def __init__(
         self,
@@ -88,7 +87,7 @@ def create_turn(
             pass
         if not chunk:
             raise Exception("No chunk returned")
-        if chunk.event.payload.event_type != 'turn_complete':
+        if chunk.event.payload.event_type != "turn_complete":
             raise Exception("Turn did not complete")
         return chunk.event.payload.turn
 
@@ -102,7 +101,7 @@ def _create_turn_streaming(
     ) -> Iterator[AgentTurnResponseStreamChunk]:
         stop = False
         n_iter = 0
-        max_iter = self.agent_config.get('max_infer_iters', DEFAULT_MAX_ITER)
+        max_iter = self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER)
         while not stop and n_iter < max_iter:
             response = self.client.agents.turn.create(
                 agent_id=self.agent_id,

src/llama_stack_client/lib/agents/client_tool.py

Lines changed: 1 addition & 4 deletions
@@ -46,10 +46,7 @@ def parameters_for_system_prompt(self) -> str:
             {
                 "name": self.get_name(),
                 "description": self.get_description(),
-                "parameters": {
-                    name: definition.__dict__
-                    for name, definition in self.get_params_definition().items()
-                },
+                "parameters": {name: definition.__dict__ for name, definition in self.get_params_definition().items()},
             }
         )

src/llama_stack_client/lib/agents/event_logger.py

Lines changed: 9 additions & 27 deletions
@@ -60,32 +60,22 @@ def __init__(self):
         self.previous_step_type = None
 
     def yield_printable_events(self, chunk):
-        for printable_event in self._yield_printable_events(
-            chunk, self.previous_event_type, self.previous_step_type
-        ):
+        for printable_event in self._yield_printable_events(chunk, self.previous_event_type, self.previous_step_type):
             yield printable_event
 
-        self.previous_event_type, self.previous_step_type = (
-            self._get_event_type_step_type(chunk)
-        )
+        self.previous_event_type, self.previous_step_type = self._get_event_type_step_type(chunk)
 
-    def _yield_printable_events(
-        self, chunk, previous_event_type=None, previous_step_type=None
-    ):
+    def _yield_printable_events(self, chunk, previous_event_type=None, previous_step_type=None):
         if hasattr(chunk, "error"):
-            yield TurnStreamPrintableEvent(
-                role=None, content=chunk.error["message"], color="red"
-            )
+            yield TurnStreamPrintableEvent(role=None, content=chunk.error["message"], color="red")
             return
 
         if not hasattr(chunk, "event"):
             # Need to check for custom tool first
             # since it does not produce event but instead
             # a Message
             if isinstance(chunk, ToolResponseMessage):
-                yield TurnStreamPrintableEvent(
-                    role="CustomTool", content=chunk.content, color="green"
-                )
+                yield TurnStreamPrintableEvent(role="CustomTool", content=chunk.content, color="green")
                 return
 
         event = chunk.event
@@ -101,9 +91,7 @@ def _yield_printable_events(
         if step_type == "shield_call" and event_type == "step_complete":
             violation = event.payload.step_details.violation
             if not violation:
-                yield TurnStreamPrintableEvent(
-                    role=step_type, content="No Violation", color="magenta"
-                )
+                yield TurnStreamPrintableEvent(role=step_type, content="No Violation", color="magenta")
             else:
                 yield TurnStreamPrintableEvent(
                     role=step_type,
@@ -114,9 +102,7 @@ def _yield_printable_events(
         # handle inference
         if step_type == "inference":
             if event_type == "step_start":
-                yield TurnStreamPrintableEvent(
-                    role=step_type, content="", end="", color="yellow"
-                )
+                yield TurnStreamPrintableEvent(role=step_type, content="", end="", color="yellow")
             elif event_type == "step_progress":
                 if event.payload.delta.type == "tool_call":
                     if isinstance(event.payload.delta.tool_call, str):
@@ -167,13 +153,9 @@ def _yield_printable_events(
 
     def _get_event_type_step_type(self, chunk):
         if hasattr(chunk, "event"):
-            previous_event_type = (
-                chunk.event.payload.event_type if hasattr(chunk, "event") else None
-            )
+            previous_event_type = chunk.event.payload.event_type if hasattr(chunk, "event") else None
             previous_step_type = (
-                chunk.event.payload.step_type
-                if previous_event_type not in {"turn_start", "turn_complete"}
-                else None
+                chunk.event.payload.step_type if previous_event_type not in {"turn_start", "turn_complete"} else None
             )
             return previous_event_type, previous_step_type
         return None, None

src/llama_stack_client/lib/cli/configure.py

Lines changed: 1 addition & 2 deletions
@@ -11,8 +11,7 @@
 from prompt_toolkit import prompt
 from prompt_toolkit.validation import Validator
 
-from llama_stack_client.lib.cli.constants import (
-    LLAMA_STACK_CLIENT_CONFIG_DIR, get_config_file_path)
+from llama_stack_client.lib.cli.constants import LLAMA_STACK_CLIENT_CONFIG_DIR, get_config_file_path
 
 
 def get_config():

src/llama_stack_client/lib/cli/inference/inference.py

Lines changed: 1 addition & 0 deletions
@@ -45,5 +45,6 @@ def chat_completion(ctx, message: str, stream: bool, model_id: Optional[str]):
     for event in EventLogger().log(response):
         event.print()
 
+
 # Register subcommands
 inference.add_command(chat_completion)

src/llama_stack_client/lib/cli/llama_stack_client.py

Lines changed: 3 additions & 9 deletions
@@ -29,15 +29,9 @@
 
 
 @click.group()
-@click.version_option(
-    version=version("llama-stack-client"), prog_name="llama-stack-client"
-)
-@click.option(
-    "--endpoint", type=str, help="Llama Stack distribution endpoint", default=""
-)
-@click.option(
-    "--api-key", type=str, help="Llama Stack distribution API key", default=""
-)
+@click.version_option(version=version("llama-stack-client"), prog_name="llama-stack-client")
+@click.option("--endpoint", type=str, help="Llama Stack distribution endpoint", default="")
+@click.option("--api-key", type=str, help="Llama Stack distribution API key", default="")
 @click.option("--config", type=str, help="Path to config file", default=None)
 @click.pass_context
 def cli(ctx, endpoint: str, api_key: str, config: str | None):
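
With the global options collapsed onto single lines, the CLI surface reads at a glance: --endpoint, --api-key, and --config are accepted before any subcommand. A hypothetical invocation for illustration only (the endpoint value and the models list subcommand are assumptions, not taken from this diff):

    llama-stack-client --endpoint http://localhost:8321 models list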

src/llama_stack_client/lib/cli/models/models.py

Lines changed: 1 addition & 3 deletions
@@ -44,9 +44,7 @@ def list_models(ctx):
     # Configure columns with specific styling
     table.add_column("model_type", style="blue")
     table.add_column("identifier", style="bold cyan", no_wrap=True, overflow="fold")
-    table.add_column(
-        "provider_resource_id", style="yellow", no_wrap=True, overflow="fold"
-    )
+    table.add_column("provider_resource_id", style="yellow", no_wrap=True, overflow="fold")
     table.add_column("metadata", style="magenta", max_width=30, overflow="fold")
     table.add_column("provider_id", style="green", max_width=20)

src/llama_stack_client/lib/cli/post_training/post_training.py

Lines changed: 3 additions & 8 deletions
@@ -9,8 +9,7 @@
 import click
 from rich.console import Console
 
-from llama_stack_client.types.post_training_supervised_fine_tune_params import (
-    AlgorithmConfigParam, TrainingConfig)
+from llama_stack_client.types.post_training_supervised_fine_tune_params import AlgorithmConfigParam, TrainingConfig
 
 from ..common.utils import handle_client_errors
 
@@ -26,9 +25,7 @@ def post_training():
 @click.option("--model", required=True, help="Model ID")
 @click.option("--algorithm-config", required=True, help="Algorithm Config")
 @click.option("--training-config", required=True, help="Training Config")
-@click.option(
-    "--checkpoint-dir", required=False, help="Checkpoint Config", default=None
-)
+@click.option("--checkpoint-dir", required=False, help="Checkpoint Config", default=None)
 @click.pass_context
 @handle_client_errors("post_training supervised_fine_tune")
 def supervised_fine_tune(
@@ -65,9 +62,7 @@ def get_training_jobs(ctx):
     console = Console()
 
     post_training_jobs = client.post_training.job.list()
-    console.print(
-        [post_training_job.job_uuid for post_training_job in post_training_jobs]
-    )
+    console.print([post_training_job.job_uuid for post_training_job in post_training_jobs])
 
 
 @click.command("status")

src/llama_stack_client/lib/cli/shields/shields.py

Lines changed: 1 addition & 3 deletions
@@ -44,9 +44,7 @@ def list(ctx):
     )
 
     table.add_column("identifier", style="bold cyan", no_wrap=True, overflow="fold")
-    table.add_column(
-        "provider_alias", style="yellow", no_wrap=True, overflow="fold"
-    )
+    table.add_column("provider_alias", style="yellow", no_wrap=True, overflow="fold")
     table.add_column("params", style="magenta", max_width=30, overflow="fold")
     table.add_column("provider_id", style="green", max_width=20)
