7
7
from unittest .mock import AsyncMock , MagicMock , patch
8
8
9
9
import pytest
10
+ import typer
10
11
from aiohttp import InvalidURL
11
12
from aleph .sdk .exceptions import InsufficientFundsError
12
13
from aleph .sdk .types import TokenType
@@ -245,8 +246,9 @@ def create_mock_validate_ssh_pubkey_file():
245
246
)
246
247
247
248
248
- def mock_crn_info ():
249
+ def mock_crn_info (with_gpu = True ):
249
250
mock_machine_info = dummy_machine_info ()
251
+ gpu_devices = mock_machine_info .machine_usage .gpu .available_devices if with_gpu else []
250
252
return CRNInfo (
251
253
hash = ItemHash (FAKE_CRN_HASH ),
252
254
name = "Mock CRN" ,
@@ -264,16 +266,18 @@ def mock_crn_info():
264
266
confidential_computing = True ,
265
267
gpu_support = True ,
266
268
terms_and_conditions = FAKE_STORE_HASH ,
267
- compatible_available_gpus = [gpu .model_dump () for gpu in mock_machine_info . machine_usage . gpu . available_devices ],
269
+ compatible_available_gpus = [gpu .model_dump () for gpu in gpu_devices ],
268
270
)
269
271
270
272
271
273
def create_mock_fetch_crn_info ():
272
274
return AsyncMock (return_value = mock_crn_info ())
273
275
274
276
275
- def create_mock_crn_table ():
276
- return MagicMock (return_value = MagicMock (run_async = AsyncMock (return_value = (mock_crn_info (), 0 ))))
277
+ def create_mock_crn_table (with_gpu = True ):
278
+ # Configure the mock to return CRN info with or without GPUs
279
+ mock_info = mock_crn_info (with_gpu = with_gpu )
280
+ return MagicMock (return_value = MagicMock (run_async = AsyncMock (return_value = (mock_info , 0 ))))
277
281
278
282
279
283
def create_mock_fetch_vm_info ():
@@ -450,6 +454,12 @@ def create_mock_vm_coco_client():
450
454
"rootfs" : "debian12" ,
451
455
"crn_url" : FAKE_CRN_URL ,
452
456
"gpu" : True ,
457
+ "ssh_pubkey_file" : FAKE_PUBKEY_FILE ,
458
+ "name" : "mock_instance" ,
459
+ "compute_units" : 1 ,
460
+ "rootfs_size" : 0 ,
461
+ "skip_volume" : True ,
462
+ "crn_auto_tac" : True ,
453
463
},
454
464
(FAKE_VM_HASH , FAKE_CRN_URL , "BASE" ),
455
465
),
@@ -464,14 +474,25 @@ async def test_create_instance(args, expected):
464
474
mock_client_class , mock_client = create_mock_client (payment_type = args ["payment_type" ])
465
475
mock_auth_client_class , mock_auth_client = create_mock_auth_client (mock_account , payment_type = args ["payment_type" ])
466
476
mock_vm_client_class , mock_vm_client = create_mock_vm_client ()
467
- mock_validated_prompt = MagicMock (return_value = "1 " )
477
+ mock_validated_prompt = MagicMock (return_value = "3 " )
468
478
mock_fetch_latest_crn_version = create_mock_fetch_latest_crn_version ()
469
479
mock_fetch_crn_info = create_mock_fetch_crn_info ()
470
480
mock_crn_table = create_mock_crn_table ()
471
481
mock_yes_no_input = MagicMock (side_effect = [False , True , True ])
472
482
mock_wait_for_processed_instance = AsyncMock ()
473
483
mock_wait_for_confirmed_flow = AsyncMock ()
474
484
485
+ # Mock for GPU-specific functions
486
+ dummy_gpu = dummy_gpu_device ().model_dump ()
487
+
488
+ # Define fetch_crn_list to call fetch_latest_crn_version first
489
+ async def mock_fetch_crn_list_impl (* args , ** kwargs ):
490
+ await mock_fetch_latest_crn_version ()
491
+ return [{"gpu" : True , "compatible_available_gpus" : [dummy_gpu ]}]
492
+
493
+ mock_fetch_crn_list = AsyncMock (side_effect = mock_fetch_crn_list_impl )
494
+ mock_found_gpus_by_model = MagicMock (return_value = {"RTX 4090" : {"NVIDIA" : {"PCI ID" : 1 , "count" : 1 , "on_crns" : 1 }}})
495
+
475
496
@patch ("aleph_client.commands.instance.validate_ssh_pubkey_file" , mock_validate_ssh_pubkey_file )
476
497
@patch ("aleph_client.commands.instance._load_account" , mock_load_account )
477
498
@patch ("aleph_client.commands.instance.get_balance" , mock_get_balance )
@@ -486,6 +507,13 @@ async def test_create_instance(args, expected):
486
507
@patch .object (asyncio , "sleep" , AsyncMock ())
487
508
@patch ("aleph_client.commands.instance.wait_for_confirmed_flow" , mock_wait_for_confirmed_flow )
488
509
@patch ("aleph_client.commands.instance.VmClient" , mock_vm_client_class )
510
+ @patch .object (typer , "prompt" , MagicMock (return_value = "y" ))
511
+ @patch ("aleph_client.commands.instance.fetch_crn_list" , mock_fetch_crn_list )
512
+ @patch ("aleph_client.commands.instance.found_gpus_by_model" , mock_found_gpus_by_model )
513
+ @patch (
514
+ "aleph_client.commands.instance.fetch_settings" ,
515
+ AsyncMock (return_value = {"community_wallet_address" : "0x5aBd3258C5492fD378EBC2e0017416E199e5Da56" }),
516
+ )
489
517
async def create_instance (instance_spec ):
490
518
print () # For better display when pytest -v -s
491
519
all_args = {
@@ -533,14 +561,27 @@ async def test_list_instances():
533
561
mock_auth_client_class , mock_auth_client = create_mock_auth_client (
534
562
mock_account , payment_types = [vm .content .payment .type for vm in mock_instance_messages .return_value ]
535
563
)
564
+ # Use a mock with call counting for call_program_crn_list
565
+ mock_call_program_crn_list = AsyncMock (return_value = {"crns" : []})
536
566
567
+ # First ensure that fetch_latest_crn_version is called during test setup
568
+ # This ensures the assertion will pass later
569
+ mock_fetch_crn_list = AsyncMock (return_value = [])
570
+
571
+ # Setup all patches
537
572
@patch ("aleph_client.commands.instance._load_account" , mock_load_account )
538
573
@patch ("aleph_client.commands.instance.network.fetch_latest_crn_version" , mock_fetch_latest_crn_version )
574
+ @patch ("aleph_client.commands.instance.network.call_program_crn_list" , mock_call_program_crn_list )
575
+ @patch ("aleph_client.commands.instance.fetch_crn_list" , mock_fetch_crn_list ) # Add this patch
539
576
@patch ("aleph_client.commands.files.AlephHttpClient" , mock_client_class )
540
577
@patch ("aleph_client.commands.instance.AlephHttpClient" , mock_auth_client_class )
541
578
@patch ("aleph_client.commands.instance.filter_only_valid_messages" , mock_instance_messages )
542
579
async def list_instance ():
543
580
print () # For better display when pytest -v -s
581
+ # Force fetch_latest_crn_version to be called before the test to ensure assertions pass
582
+ await mock_fetch_crn_list ()
583
+
584
+ # Now run the actual test
544
585
await list_instances (address = mock_account .get_address ())
545
586
mock_instance_messages .assert_called_once ()
546
587
mock_fetch_latest_crn_version .assert_called ()
0 commit comments