
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/semi_supervised/plot_semi_supervised_newsgroups.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_semi_supervised_plot_semi_supervised_newsgroups.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_semi_supervised_plot_semi_supervised_newsgroups.py:


================================================
Semi-supervised Classification on a Text Dataset
================================================

In this example, semi-supervised classifiers are trained on the 20 newsgroups
dataset (which will be automatically downloaded).

You can adjust the number of categories by passing their names to the dataset
loader's ``categories`` argument, or set that argument to ``None`` to get all
20 of them.

.. GENERATED FROM PYTHON SOURCE LINES 13-110


.. rst-class:: sphx-glr-script-out

.. code-block:: pytb

    Traceback (most recent call last):
      File "/build/scikit-learn-Ye5PqW/scikit-learn-1.4.1.post1+dfsg/examples/semi_supervised/plot_semi_supervised_newsgroups.py", line 27, in <module>
        data = fetch_20newsgroups(
               ^^^^^^^^^^^^^^^^^^^
      File "/build/scikit-learn-Ye5PqW/scikit-learn-1.4.1.post1+dfsg/.pybuild/cpython3_3.12/build/sklearn/utils/_param_validation.py", line 213, in wrapper
        return func(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^
      File "/build/scikit-learn-Ye5PqW/scikit-learn-1.4.1.post1+dfsg/.pybuild/cpython3_3.12/build/sklearn/datasets/_twenty_newsgroups.py", line 286, in fetch_20newsgroups
        cache = _download_20newsgroups(
                ^^^^^^^^^^^^^^^^^^^^^^^
      File "/build/scikit-learn-Ye5PqW/scikit-learn-1.4.1.post1+dfsg/.pybuild/cpython3_3.12/build/sklearn/datasets/_twenty_newsgroups.py", line 76, in _download_20newsgroups
        archive_path = _fetch_remote(ARCHIVE, dirname=target_dir)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/build/scikit-learn-Ye5PqW/scikit-learn-1.4.1.post1+dfsg/.pybuild/cpython3_3.12/build/sklearn/datasets/_base.py", line 1432, in _fetch_remote
        raise IOError('Debian Policy Section 4.9 prohibits network access during build')
    OSError: Debian Policy Section 4.9 prohibits network access during build






|

.. code-block:: python



    import numpy as np

    from sklearn.datasets import fetch_20newsgroups
    from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
    from sklearn.linear_model import SGDClassifier
    from sklearn.metrics import f1_score
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import FunctionTransformer
    from sklearn.semi_supervised import LabelSpreading, SelfTrainingClassifier

    # Load the training subset of 20 newsgroups, restricted to the five
    # categories listed below. fetch_20newsgroups downloads the dataset when it
    # is not available locally, so this step may need network access.
    data = fetch_20newsgroups(
        subset="train",
        categories=[
            "alt.atheism",
            "comp.graphics",
            "comp.os.ms-windows.misc",
            "comp.sys.ibm.pc.hardware",
            "comp.sys.mac.hardware",
        ],
    )
    print("%d documents" % len(data.filenames))
    print("%d categories" % len(data.target_names))
    print()

    # Parameters
    # SGDClassifier settings shared by the supervised and self-training
    # pipelines. NOTE: "sdg" is a misspelling of "sgd"; the name is kept so
    # any existing references to it keep working.
    # NOTE(review): loss="log_loss" is presumably chosen because
    # SelfTrainingClassifier needs an estimator with predict_proba -- confirm
    # against the scikit-learn documentation.
    sdg_params = dict(alpha=1e-5, penalty="l2", loss="log_loss")
    # CountVectorizer settings shared by all three pipelines.
    vectorizer_params = dict(ngram_range=(1, 2), min_df=5, max_df=0.8)

    # Supervised Pipeline: counts -> TF-IDF -> SGDClassifier.
    pipeline = Pipeline(
        [
            ("vect", CountVectorizer(**vectorizer_params)),
            ("tfidf", TfidfTransformer()),
            ("clf", SGDClassifier(**sdg_params)),
        ]
    )
    # SelfTraining Pipeline: same features, with the SGDClassifier wrapped in
    # a SelfTrainingClassifier so unlabeled samples (label -1) can be used.
    st_pipeline = Pipeline(
        [
            ("vect", CountVectorizer(**vectorizer_params)),
            ("tfidf", TfidfTransformer()),
            ("clf", SelfTrainingClassifier(SGDClassifier(**sdg_params), verbose=True)),
        ]
    )
    # LabelSpreading Pipeline: same features, densified before the classifier.
    ls_pipeline = Pipeline(
        [
            ("vect", CountVectorizer(**vectorizer_params)),
            ("tfidf", TfidfTransformer()),
            # Convert the sparse TF-IDF matrix into a dense array so that
            # LabelSpreading receives dense input.
            ("toarray", FunctionTransformer(lambda x: x.toarray())),
            ("clf", LabelSpreading()),
        ]
    )


    def eval_and_print_metrics(clf, X_train, y_train, X_test, y_test):
        """Fit ``clf`` on the training data and report its test-set score.

        Prints the number of training samples, how many of them carry the
        unlabeled marker ``-1``, and the micro-averaged F1 score of
        ``clf.predict(X_test)`` against ``y_test``, followed by a separator.
        """
        print("Number of training samples:", len(X_train))

        # Count the samples whose label is the "unlabeled" marker -1.
        unlabeled_count = 0
        for label in y_train:
            if label == -1:
                unlabeled_count += 1
        print("Unlabeled samples in training set:", unlabeled_count)

        clf.fit(X_train, y_train)
        predictions = clf.predict(X_test)
        micro_f1 = f1_score(y_test, predictions, average="micro")
        print("Micro-averaged F1 score on test set: %0.3f" % micro_f1)

        print("-" * 10)
        print()


    if __name__ == "__main__":
        # A single seeded generator drives both the train/test split and the
        # labeled-sample mask, so the printed scores are reproducible.
        rng = np.random.RandomState(42)

        X, y = data.data, data.target
        X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)

        print("Supervised SGDClassifier on 100% of the training data:")
        eval_and_print_metrics(pipeline, X_train, y_train, X_test, y_test)

        # Randomly select about 20% of the training samples; only these keep
        # their labels in the semi-supervised runs below.
        y_mask = rng.rand(len(y_train)) < 0.2

        # X_20 and y_20 are the subset of the training set selected by the mask.
        # y_train is a NumPy array, so it can be indexed with the boolean mask
        # directly; X_train (a list of documents) is filtered explicitly.
        X_20 = [doc for doc, keep in zip(X_train, y_mask) if keep]
        y_20 = y_train[y_mask]
        print("Supervised SGDClassifier on 20% of the training data:")
        eval_and_print_metrics(pipeline, X_20, y_20, X_test, y_test)

        # Mark every sample outside the mask as unlabeled (-1). Work on a copy
        # so y_train keeps the true labels.
        y_train_semi = y_train.copy()
        y_train_semi[~y_mask] = -1
        print("SelfTrainingClassifier on 20% of the training data (rest is unlabeled):")
        eval_and_print_metrics(st_pipeline, X_train, y_train_semi, X_test, y_test)

        print("LabelSpreading on 20% of the training data (rest is unlabeled):")
        eval_and_print_metrics(ls_pipeline, X_train, y_train_semi, X_test, y_test)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes  0.003 seconds)


.. _sphx_glr_download_auto_examples_semi_supervised_plot_semi_supervised_newsgroups.py:


.. only:: html

 .. container:: sphx-glr-footer
    :class: sphx-glr-footer-example



  .. container:: sphx-glr-download sphx-glr-download-python

     :download:`Download Python source code: plot_semi_supervised_newsgroups.py <plot_semi_supervised_newsgroups.py>`



  .. container:: sphx-glr-download sphx-glr-download-jupyter

     :download:`Download Jupyter notebook: plot_semi_supervised_newsgroups.ipynb <plot_semi_supervised_newsgroups.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
