Skip to content

Commit b9ed1a4

Browse files
committed
linter update
1 parent 0137827 commit b9ed1a4

File tree

2 files changed

+1
-10
lines changed

2 files changed

+1
-10
lines changed

MaxText/input_pipeline/_syn_data_processing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,4 +114,4 @@ def get_place_holder_synthetic_data(config: pyconfig.HyperParameters):
114114
dataset = tf.data.Dataset.zip((output)) # pytype: disable=wrong-arg-types
115115
dataset = dataset.repeat()
116116
dataset = dataset.batch(config.global_batch_size_to_load // jax.process_count())
117-
return dataset
117+
return dataset

MaxText/input_pipeline/input_pipeline_interface.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,11 @@
1515
"""
1616

1717
"""Input pipeline"""
18-
from collections.abc import Callable
19-
from typing import Any
2018
import functools
2119

22-
import numpy as np
23-
24-
import tensorflow as tf
25-
2620
import jax
27-
import jax.numpy as jnp
2821
from jax.sharding import PartitionSpec as P
2922

30-
from MaxText import multihost_dataloading
3123
from MaxText import pyconfig
3224
from MaxText import max_logging
3325
from MaxText.input_pipeline._grain_data_processing import make_grain_train_iterator, make_grain_eval_iterator
@@ -94,7 +86,6 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh):
9486
assert len(process_indices_train) == jax.process_count() // config.expansion_factor_real_data
9587
if config.eval_interval > 0:
9688
assert len(process_indices_eval) == jax.process_count() // config.expansion_factor_real_data
97-
9889
# Generate iterator functions according to dataset type
9990
if config.dataset_type in ["tfds", "grain", "hf", "c4_mlperf"]:
10091
if config.dataset_type == "c4_mlperf":

0 commit comments

Comments
 (0)