@@ -478,6 +478,11 @@ def func(modelcls: AnyModel) -> AnyModel:
478
478
return modelcls
479
479
return func
480
480
481
+ @classmethod
482
+ def print_registered_models (cls ):
483
+ for name in cls ._model_classes .keys ():
484
+ logger .error (f"- { name } " )
485
+
481
486
@classmethod
482
487
def from_model_architecture (cls , arch : str ) -> type [Model ]:
483
488
try :
@@ -4929,6 +4934,7 @@ def parse_args() -> argparse.Namespace:
4929
4934
parser .add_argument (
4930
4935
"model" , type = Path ,
4931
4936
help = "directory containing model file" ,
4937
+ nargs = "?" ,
4932
4938
)
4933
4939
parser .add_argument (
4934
4940
"--use-temp-file" , action = "store_true" ,
@@ -4966,8 +4972,15 @@ def parse_args() -> argparse.Namespace:
4966
4972
"--metadata" , type = Path ,
4967
4973
help = "Specify the path for an authorship metadata override file"
4968
4974
)
4975
+ parser .add_argument (
4976
+ "--print-supported-models" , action = "store_true" ,
4977
+ help = "Print the supported models"
4978
+ )
4969
4979
4970
- return parser .parse_args ()
4980
+ args = parser .parse_args ()
4981
+ if not args .print_supported_models and args .model is None :
4982
+ parser .error ("the following arguments are required: model" )
4983
+ return args
4971
4984
4972
4985
4973
4986
def split_str_to_n_bytes (split_str : str ) -> int :
@@ -4991,6 +5004,11 @@ def split_str_to_n_bytes(split_str: str) -> int:
4991
5004
def main () -> None :
4992
5005
args = parse_args ()
4993
5006
5007
+ if args .print_supported_models :
5008
+ logger .error ("Supported models:" )
5009
+ Model .print_registered_models ()
5010
+ sys .exit (0 )
5011
+
4994
5012
if args .verbose :
4995
5013
logging .basicConfig (level = logging .DEBUG )
4996
5014
else :
0 commit comments