Skip to content

Commit

Permalink
fix trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Oct 23, 2024
1 parent 1328c9e commit 6c5fe5a
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 11 deletions.
21 changes: 12 additions & 9 deletions optimum/intel/neural_compressor/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from transformers import Trainer
from transformers.data.data_collator import DataCollator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype, unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
Expand Down Expand Up @@ -104,7 +105,7 @@
from neural_compressor.config import _BaseQuantizationConfig


__version__ = "4.22.2"
__version__ = "4.46.0"


logger = logging.get_logger(__name__)
Expand All @@ -122,8 +123,9 @@ def __init__(
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Dataset] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
processing_class: Optional[Union[PreTrainedTokenizerBase, FeatureExtractionMixin]] = None,
model_init: Callable[[], PreTrainedModel] = None,
compute_loss_func: Optional[Callable] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
Expand All @@ -132,6 +134,7 @@ def __init__(
pruning_config: Optional[_BaseQuantizationConfig] = None,
distillation_config: Optional[_BaseQuantizationConfig] = None,
task: Optional[str] = None,
**kwargs,
):
self.neftune_noise_alpha = None

Expand All @@ -141,12 +144,12 @@ def __init__(
data_collator,
train_dataset,
eval_dataset,
tokenizer,
model_init,
compute_metrics,
callbacks,
optimizers,
preprocess_logits_for_metrics,
processing_class or kwargs.get("tokenizer", None),
model_init=model_init,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

if self.args.device.type == "cuda" and not is_neural_compressor_version(">", "2.0.0"):
Expand Down Expand Up @@ -766,7 +769,7 @@ def _get_logits(model_outputs):
output_names = ["logits", "start_logits", "end_logits"]
return tuple(model_outputs.get(name) for name in output_names if name in model_outputs)

def compute_loss(self, model, inputs, return_outputs=False):
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
"""
Expand Down
5 changes: 5 additions & 0 deletions optimum/intel/openvino/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,11 @@ def __init__(
logger.warning("OVTrainer is deprecated and will be removed in optimum-intel v1.22.0.")

if is_transformers_version(">=", "4.45.0"):
if is_transformers_version(">=", "4.46.0"):
raise ImportError(
f"Unsupported transformers version found is {_transformers_version} which is not supported by the OVTrainer. Please downgrade to v4.44"
)

logger.warning(
f"The transformers version found is {_transformers_version} which is not officially supported by the OVTrainer, use at your own risk"
)
Expand Down
3 changes: 3 additions & 0 deletions tests/openvino/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,9 @@ class OVTrainerTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = (("albert", 64, 39),)

@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
@unittest.skipIf(
is_transformers_version(">=", "4.46"), reason="OVTrainer is not compatible with transformers>=v4.46"
)
def test_aware_training_quantization(self, model_name, expected_fake_quantize, expected_int8):
model_id = MODEL_NAMES[model_name]
model = AutoModelForSequenceClassification.from_pretrained(model_id, attn_implementation="eager")
Expand Down
13 changes: 11 additions & 2 deletions tests/openvino/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,10 @@ class OVTrainerTextClassificationTrainingTest(OVTrainerBaseTrainingTest):
task = "sequence-classification"

@parameterized.expand(OVTRAINER_TEXT_CLASSIFICATION_TEST_DESCRIPTORS.items())
@unittest.skipIf(is_transformers_version("<", "4.41.0"), reason="Mismatch in expected fake quantized op")
@unittest.skipIf(
is_transformers_version("<", "4.41") or is_transformers_version(">=", "4.46"),
reason="Mismatch in expected fake quantized op and incompatible with transformers v4.46",
)
def test_training(self, _, desc: OVTrainerTestDescriptor):
self.run_ovtrainer_training_checks(desc)

Expand Down Expand Up @@ -627,7 +630,10 @@ class OVTrainerImageClassificationTrainingTest(OVTrainerBaseTrainingTest):
@parameterized.expand(OVTRAINER_IMAGE_CLASSIFICATION_TEST_DESCRIPTORS.items())
@pytest.mark.run_slow
@slow
@unittest.skipIf(is_transformers_version("<", "4.41.0"), reason="Mismatch in expected fake quantized op")
@unittest.skipIf(
is_transformers_version("<", "4.41") or is_transformers_version(">=", "4.46"),
reason="Mismatch in expected fake quantized op and incompatible with transformers v4.46",
)
def test_training(self, _, desc: OVTrainerTestDescriptor):
self.run_ovtrainer_training_checks(desc)

Expand Down Expand Up @@ -808,6 +814,9 @@ class OVTrainerAudioClassificationTrainingTest(OVTrainerBaseTrainingTest):
@parameterized.expand(OVTRAINER_AUDIO_CLASSIFICATION_TEST_DESCRIPTORS.items())
@pytest.mark.run_slow
@slow
@unittest.skipIf(
is_transformers_version(">=", "4.46"), reason="OVTrainer is not compatible with transformers>=v4.46"
)
def test_training(self, _, desc: OVTrainerTestDescriptor):
self.run_ovtrainer_training_checks(desc)

Expand Down

0 comments on commit 6c5fe5a

Please sign in to comment.