Practical activity

Duration: 5h

Presentation & objectives

This practical activity is divided into two parts. First, you will go through a series of exercises to familiarize yourself with functions and variable visibility. This is a fundamental notion to master in any programming language.

Then, you will dive into the details of creating Python modules. Modules will allow you to organize your code and make it easier to reuse in future projects. To illustrate this, we propose to write a small library of functions that implement a linear regressor.

Important

The aim of this session is to help you master important notions in computer science. An intelligent programming assistant such as GitHub Copilot, which you may have already installed, can provide you with a solution to these exercises based only on a well-chosen file name.

For the sake of training, we advise you to disable such tools first.

At the end of the practical activity, we suggest you work on the exercises again with these tools enabled. Following these two steps will improve your skills both fundamentally and practically.

We also provide the solutions to the exercises. Make sure to check them only after you have your own solution, for comparison purposes! Even if you are sure your solution is correct, please have a look at them, as they sometimes provide additional elements you may have missed.

Activity contents (part 1)

1 — List of random numbers

Remember the practical activity of programming session 1? We played a bit with random numbers, especially to produce a list of such numbers. This is a very common thing to do when programming, so it may be a good idea to make a function out of it. Later, when we need a list of random numbers, we can thus generate one with the function instead of copy-pasting the same code over and over.

Write a function that returns a list of randomly generated integers within an interval. The function takes as input the number of values to generate as well as the bounds of the interval. You need to use a function from the random module that generates an integer within a given interval.

Here is the prototype of this function: def rand_ints (min: int, max: int, nb: int) -> List[int]. The List type hint can be imported from the typing module.

A default value may be given to each parameter. If a function is called without an argument for a parameter, the default value of that parameter is used. Make sure your function generates 10 numbers between 0 and 100 if no arguments are given.
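To see default values in action without giving away the solution, here is a minimal sketch with a hypothetical power function (not part of the exercise), whose exponent parameter defaults to 2. It also shows that an argument can be passed by keyword, using the parameter name.

```python
def power(base: int, exponent: int = 2) -> int:
    """
        Raise base to the given exponent (hypothetical example function).
        In:
            * base:     The number to raise.
            * exponent: The exponent, 2 by default.
        Out:
            * base raised to the power exponent.
    """

    # When no argument is given for exponent, its default value (2) is used
    return base ** exponent


# Uses the default value of exponent
print(power(3))               # 9

# Overrides the default value, positionally or by keyword
print(power(2, 3))            # 8
print(power(2, exponent=10))  # 1024
```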

Do not forget to type your function prototype (i.e., its parameters and returned value) and to comment it.

Correction
# Needed imports
from typing import List
from random import randint



def rand_ints (min: int = 0, max: int = 100, nb: int = 10) -> List[int]:

    """
        Generate a list of random integers between min (included) and max (included).
        In:
            * min: The minimum value for the random integers.
            * max: The maximum value for the random integers.
            * nb:  The number of random integers to generate.
        Out:
            * A list of random integers.
    """

    # Generate nb random integers
    rand_ints = []
    for i in range(nb):
        rand_ints.append(randint(min, max))

    # Return the list of random integers
    return rand_ints



# Example usage with default parameters
random_numbers = rand_ints()
print(random_numbers)

# Example usage with custom parameters
custom_random_numbers = rand_ints(10, 50, 5)
print(custom_random_numbers)
// Needed imports
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /**
     * Generate a list of random integers between min (included) and max (included).
     *
     * @param min The minimum value for the random integers.
     * @param max The maximum value for the random integers.
     * @param nb  The number of random integers to generate.
     * @return    A list of random integers.
     */
    public static List<Integer> randInts(int min, int max, int nb) {
        // Create a list to store the random integers
        List<Integer> randInts = new ArrayList<>();
        Random rand = new Random();

        // Generate nb random integers
        for (int i = 0; i < nb; i++) {
            int randomInt = rand.nextInt((max - min) + 1) + min;
            randInts.add(randomInt);
        }

        // Return the list of random integers
        return randInts;
    }

    /**
     * Overloaded method to use default parameters.
     *
     * @return 10 random integers in [0, 100].
     */
    public static List<Integer> randInts() {
        return randInts(0, 100, 10);
    }

    /**
     * This is the entry point of your program.
     * It contains the first codes that are going to be executed.
     *
     * @param args Command line arguments received.
     */
    public static void main(String[] args) {
        // Example usage with default parameters
        List<Integer> randomNumbers = randInts();
        System.out.println(randomNumbers);

        // Example usage with custom parameters
        List<Integer> customRandomNumbers = randInts(10, 50, 5);
        System.out.println(customRandomNumbers);
    }

}

2 — Filter a list

Functions can take many different types of parameters: basic types, data structures, objects, functions, etc.
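In Python, functions are objects like any other, so a function can be passed as an argument and called inside another function. The following minimal sketch uses two hypothetical functions, apply_twice and add_three, to illustrate this idea.

```python
# Needed imports
from typing import Callable


def apply_twice(f: Callable[[int], int], x: int) -> int:
    """
        Apply a function twice to a value (hypothetical example function).
        In:
            * f: A function taking an int and returning an int.
            * x: The initial value.
        Out:
            * The value f(f(x)).
    """

    # f is a regular parameter that happens to hold a function
    return f(f(x))


def add_three(n: int) -> int:
    """
        Return n + 3 (hypothetical example function).
    """

    return n + 3


# Pass the function itself (no parentheses), not the result of a call
print(apply_twice(add_three, 1))  # 7
```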

Start by defining two simple functions, each taking an integer n as parameter and returning a boolean:

  • The first function checks that n is a prime number.
  • The second one checks that n is even.
Correction
def is_prime (n: int) -> bool:

    """
        This function tests if a number is prime or not.
        In:
            * n: The number to be tested.
        Out:
            * True if the number is prime, False otherwise.
    """

    # Must be greater than 1
    if n <= 1:
        return False

    # Test all dividers
    for i in range(2, n):
        if n % i == 0:
            return False
        
    # If we reach here, it is prime
    return True



def is_even (n: int) -> bool:

    """
        This function tests if a number is even or not.
        In:
            * n: The number to be tested.
        Out:
            * True if the number is even, False otherwise.
    """

    # Check modulo 2
    return n % 2 == 0



# Test is_prime
print(29, "is prime:", is_prime(29))
print(15, "is prime:", is_prime(15))

# Test is_even
print(29, "is even:", is_even(29))
print(14, "is even:", is_even(14))
/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /**
     * This function tests if a number is prime or not.
     *
     * @param n The number to be tested.
     * @return  True if the number is prime, False otherwise.
     */
    public static boolean isPrime(int n) {
        // Must be greater than 1
        if (n <= 1) {
            return false;
        }

        // Test all dividers
        for (int i = 2; i < n; i++) {
            if (n % i == 0) {
                return false;
            }
        }

        // If we reach here, it is prime
        return true;
    }

    /**
     * This function tests if a number is even or not.
     *
     * @param n The number to be tested.
     * @return  True if the number is even, False otherwise.
     */
    public static boolean isEven(int n) {
        // Check modulo 2
        return n % 2 == 0;
    }

    /**
     * This is the entry point of your program.
     * It contains the first codes that are going to be executed.
     *
     * @param args Command line arguments received.
     */
    public static void main(String[] args) {
        // Test isPrime
        System.out.println(29 + " is prime: " + isPrime(29));
        System.out.println(15 + " is prime: " + isPrime(15));

        // Test isEven
        System.out.println(29 + " is even: " + isEven(29));
        System.out.println(14 + " is even: " + isEven(14));
    }

}

Define a filter_list function that takes as parameters a list of integers and a test function with the same prototype as the two functions defined above. This filter function returns the list of elements for which the test function returns True.

Correction
# Needed imports
from typing import List, Callable



def is_prime (n: int) -> bool:

    """
        This function tests if a number is prime or not.
        In:
            * n: The number to be tested.
        Out:
            * True if the number is prime, False otherwise.
    """

    # Must be greater than 1
    if n <= 1:
        return False

    # Test all dividers
    for i in range(2, n):
        if n % i == 0:
            return False
        
    # If we reach here, it is prime
    return True



def is_even (n: int) -> bool:

    """
        This function tests if a number is even or not.
        In:
            * n: The number to be tested.
        Out:
            * True if the number is even, False otherwise.
    """

    return n % 2 == 0



def filter_list (l: List[int], f: Callable[[int], bool]) -> List[int]:

    """
        This function filters a list based on a given function.
        In:
            * l: The list to be filtered.
            * f: The function to be used for filtering.
        Out:
            * A new list containing only the elements that pass the filter.
    """

    # Iterate over the list
    new_list = []
    for elem in l:
        if f(elem):
            new_list.append(elem)
    
    # Return the new list
    return new_list



# Test is_prime
print(29, "is prime:", is_prime(29))
print(15, "is prime:", is_prime(15))

# Test is_even
print(29, "is even:", is_even(29))
print(14, "is even:", is_even(14))

# Test filter_list
l = list(range(100))
print("Even numbers:", filter_list(l, is_even))
print("Prime numbers:", filter_list(l, is_prime))
// Needed imports
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /**
     * This function tests if a number is prime or not.
     *
     * @param n The number to be tested.
     * @return  True if the number is prime, False otherwise.
     */
    public static boolean isPrime(int n) {
        // Must be greater than 1
        if (n <= 1) {
            return false;
        }

        // Test all dividers
        for (int i = 2; i < n; i++) {
            if (n % i == 0) {
                return false;
            }
        }

        // If we reach here, it is prime
        return true;
    }

    /**
     * This function tests if a number is even or not.
     *
     * @param n The number to be tested.
     * @return  True if the number is even, False otherwise.
     */
    public static boolean isEven(int n) {
        // Check modulo 2
        return n % 2 == 0;
    }

    /**
     * This function filters a list based on a given function.
     *
     * @param l The list to be filtered.
     * @param f The function to be used for filtering.
     * @return  A new list containing only the elements that pass the filter.
     */
    public static List<Integer> filterList(List<Integer> l, Predicate<Integer> f) {
        // Iterate over the list
        List<Integer> newList = new ArrayList<>();
        for (int elem : l) {
            if (f.test(elem)) {
                newList.add(elem);
            }
        }

        // Return the new list
        return newList;
    }

    /**
     * This is the entry point of your program.
     * It contains the first codes that are going to be executed.
     *
     * @param args Command line arguments received.
     */
    public static void main(String[] args) {
        // Test isPrime
        System.out.println(29 + " is prime: " + isPrime(29));
        System.out.println(15 + " is prime: " + isPrime(15));

        // Test isEven
        System.out.println(29 + " is even: " + isEven(29));
        System.out.println(14 + " is even: " + isEven(14));

        // Test filterList
        List<Integer> l = new ArrayList<>();
        for (int i = 0; i < 100; i++)
        {
            l.add(i);
        }
        System.out.println("Even numbers: " + filterList(l, Main::isEven));
        System.out.println("Prime numbers: " + filterList(l, Main::isPrime));
    }

}

Lambda functions make it possible to define functions without having to name them. Here is an example of a lambda function passed as an argument to the filter_list function:

# Same result as is_even, using a lambda function
l = list(range(100))
print("Even numbers:", filter_list(l, lambda x : x % 2 == 0))
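// Needed imports
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /**
     * This function filters a list based on a given function.
     *
     * @param l The list to be filtered.
     * @param f The function to be used for filtering.
     * @return  A new list containing only the elements that pass the filter.
     */
    public static List<Integer> filterList(List<Integer> l, Predicate<Integer> f) {
        // Iterate over the list
        List<Integer> newList = new ArrayList<>();
        for (int elem : l) {
            if (f.test(elem)) {
                newList.add(elem);
            }
        }

        // Return the new list
        return newList;
    }

    /**
     * This is the entry point of your program.
     * It contains the first codes that are going to be executed.
     *
     * @param args Command line arguments received.
     */
    public static void main(String[] args) {
        // Initialize the list
        List<Integer> l = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            l.add(i);
        }

        // Filter the list using a lambda function to find even numbers
        List<Integer> evenNumbers = filterList(l, x -> x % 2 == 0);
        System.out.println("Even numbers: " + evenNumbers);
    }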

}
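For comparison, Python already provides a built-in filter() function, and list comprehensions express the same filtering in a single line. This sketch reproduces the even-number filter on range(10) to keep the output short.

```python
# Same filtering as with filter_list and a lambda, using built-in tools
numbers = list(range(10))

# The built-in filter() returns a lazy iterator, hence the conversion to a list
evens_filter = list(filter(lambda x: x % 2 == 0, numbers))

# A list comprehension is often considered the most idiomatic form in Python
evens_comprehension = [x for x in numbers if x % 2 == 0]

print(evens_filter)         # [0, 2, 4, 6, 8]
print(evens_comprehension)  # [0, 2, 4, 6, 8]
```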

3 — Visibility

Python stores local and global variables in two separate namespaces, which can be inspected as dictionaries with the built-in locals() and globals() functions. Test the following code and analyze its output:

# Define a global string
val_str = "global func1"



def func_1 () -> None:

    """
        A simple function to test local/global visibility.
        In:
            * None.
        Out:
            * None.
    """

    # Define a local string
    val_str : str = "local func1"
    print(val_str)

    # Print the local and global variables
    print("Local variables in func1:", locals())
    print("Global variables in func1:", globals())



def func_2 () -> None:

    """
        A simple function to test local/global visibility.
        In:
            * None.
        Out:
            * None.
    """

    # Print a string
    print(val_str)

    # Print the local and global variables
    print("Local variables in func2:", locals())
    print("Global variables in func2:", globals())



# Call the functions
func_1()
func_2()

# Print the local and global variables
print("Local variables in main:", locals())
print("Global variables in main:", globals())
/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /** Define a global string (class-level variable). */
    static String valStr = "global func1";

    /**
     * A simple function to test local/global visibility.
     */
    public static void func1() {
        // Define a local string
        String valStr = "local func1";
        System.out.println(valStr);

        // Print the local variable (valStr) and global variable (Main.valStr)
        System.out.println("Local valStr in func1: " + valStr);
        System.out.println("Global valStr in func1: " + Main.valStr);
    }

    /**
     * A simple function to test local/global visibility.
     */
    public static void func2() {
        // Print the global string
        System.out.println(Main.valStr);

        // In Java, we can't directly access the local variables of another method
        System.out.println("Global valStr in func2: " + Main.valStr);
    }

    /**
     * This is the entry point of your program.
     * It contains the first codes that are going to be executed.
     *
     * @param args Command line arguments received.
     */
    public static void main(String[] args) {
        // Call the functions
        func1();
        func2();

        // Print the global variable in the main method
        System.out.println("Global valStr in main: " + Main.valStr);
    }

}

Modify the code of func_1 so that its first statement modifies the value of the global variable val_str.

Correction
def func_1 () -> None:

    """
        A simple function to test local/global visibility.
        In:
            * None.
        Out:
            * None.
    """

    # Declare the string global to modify it
    global val_str
    val_str = "modified global"
    print(val_str)
/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /**
     * A simple function to test local/global visibility.
     */
    public static void func1() {
        // Modify the global string
        Main.valStr = "modified global";
        System.out.println(Main.valStr);
    }

}

4 — Immutable vs. mutable parameters

A key concept to master regarding variables is the difference between mutable and immutable types. To experiment with the difference between them, run the following code and analyze its output.

# Needed imports
from typing import List



def func_1 (a : int) -> None:

    """
        A simple function to test mutability.
        In:
            * a: A non-mutable parameter.
        Out:
            * None.
    """

    # Increment a
    a += 1



def func_2 (b : List[int]) -> None:

    """
        A simple function to test mutability.
        In:
            * b: A mutable parameter.
        Out:
            * None.
    """

    # Append an element to the list
    b.append(0)



# Call func_1
x = 1
func_1(x)
print(x)

# Call func_2
y = []
func_2(y)
print(y)
// Needed imports
import java.util.ArrayList;
import java.util.List;

/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /**
     * A simple function to test mutability.
     *
     * @param a A non-mutable parameter.
     */
    public static void func1(int a) {
        a += 1;
    }

    /**
     * A simple function to test mutability.
     *
     * @param b A mutable parameter.
     */
    public static void func2(List<Integer> b) {
        b.add(0);
    }

    /**
     * This is the entry point of your program.
     * It contains the first codes that are going to be executed.
     *
     * @param args Command line arguments received.
     */
    public static void main(String[] args) {
        // Call func1
        int x = 1;
        func1(x);
        System.out.println(x);

        // Call func2
        List<Integer> y = new ArrayList<>();
        func2(y);
        System.out.println(y);
    }

}

Each object has an identifier, which can be retrieved with the built-in id() function. Compare the identifiers of the variables and of the function arguments to understand why parameters of a mutable type can be modified within functions.

Correction
# Needed imports
from typing import List



def func_1 (a : int) -> None:

    """
        A simple function to test mutability.
        In:
            * a: A non-mutable parameter.
        Out:
            * None.
    """

    # Show ID
    print("ID of a in func_1 (before increase): ", id(a))
    a += 1
    print("ID of a in func_1 (after increase): ", id(a))



def func_2 (b : List[int]) -> None:

    """
        A simple function to test mutability.
        In:
            * b: A mutable parameter.
        Out:
            * None.
    """

    # Show ID
    print("ID of b in func_2 (before append): ", id(b))
    b.append(0)
    print("ID of b in func_2 (after append): ", id(b))



# Call func_1
x = 1
print("ID of x (before call): ", id(x))
func_1(x)
print("ID of x (after call): ", id(x))

# Call func_2
y = []
print("ID of y (before call): ", id(y))
func_2(y)
print("ID of y (after call): ", id(y))
// Needed imports
import java.util.ArrayList;
import java.util.List;

/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /**
     * A simple function to test mutability.
     *
     * @param a A non-mutable parameter.
     */
    public static void func1(int a) {
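        // Show ID (a primitive int has no identity in Java: it is autoboxed to an
        // Integer here, so this only approximates Python's id())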
        System.out.println("ID of a in func1 (before increase): " + System.identityHashCode(a));
        a += 1;
        System.out.println("ID of a in func1 (after increase): " + System.identityHashCode(a));
    }

    /**
     * A simple function to test mutability.
     *
     * @param b A mutable parameter.
     */
    public static void func2(List<Integer> b) {
        // Show ID
        System.out.println("ID of b in func2 (before append): " + System.identityHashCode(b));
        b.add(0);
        System.out.println("ID of b in func2 (after append): " + System.identityHashCode(b));
    }

    /**
     * This is the entry point of your program.
     * It contains the first codes that are going to be executed.
     *
     * @param args Command line arguments received.
     */
    public static void main(String[] args) {
        // Call func1
        int x = 1;
        System.out.println("ID of x (before call): " + System.identityHashCode(x));
        func1(x);
        System.out.println("ID of x (after call): " + System.identityHashCode(x));

        // Call func2
        List<Integer> y = new ArrayList<>();
        System.out.println("ID of y (before call): " + System.identityHashCode(y));
        func2(y);
        System.out.println("ID of y (after call): " + System.identityHashCode(y));
    }

}
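The comparison of identifiers suggests the underlying rule: arguments are passed as references to objects. A function can therefore modify a mutable object in place, but assigning a new value to a parameter only rebinds the local name. The sketch below contrasts both cases with two hypothetical functions, rebind and mutate.

```python
# Needed imports
from typing import List


def rebind(values: List[int]) -> None:
    """
        Rebind the parameter to a new list: the caller's list is untouched.
    """

    # values now refers to a brand-new list object, only inside this function
    values = [99]
    print("Inside rebind:", values, id(values))


def mutate(values: List[int]) -> None:
    """
        Modify the list object in place: the caller sees the change.
    """

    # values still refers to the caller's list object
    values.append(99)
    print("Inside mutate:", values, id(values))


data = [1, 2]
print("Before calls:", data, id(data))
rebind(data)
print("After rebind:", data)  # [1, 2]
mutate(data)
print("After mutate:", data)  # [1, 2, 99]
```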

In some situations, you may need to prevent a function from modifying its arguments, even if they are of a mutable type. Using a copy of the list, modify the call to func_2 so that the list y passed as an argument is not modified.

Correction
# Call func_2
y = []
print("ID of y (before call): ", id(y))
func_2(y.copy())
print("ID of y (after call): ", id(y))
print("Content of y (after call): ", y)
/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /**
     * This is the entry point of your program.
     * It contains the first codes that are going to be executed.
     *
     * @param args Command line arguments received.
     */
    public static void main(String[] args) {
        // Call func2
        List<Integer> y = new ArrayList<>();
        System.out.println("ID of y (before call): " + System.identityHashCode(y));
        func2(new ArrayList<>(y));
        System.out.println("ID of y (after call): " + System.identityHashCode(y));
    }

}
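Note that y.copy(), like list(y) or y[:], makes a shallow copy: only the outer list is duplicated, and the elements themselves are shared. This is enough for a list of integers, but for nested mutable structures the deepcopy() function of the standard copy module may be needed, as this minimal sketch shows.

```python
# Needed imports
import copy
from typing import List

# A list containing other lists (nested mutable objects)
nested: List[List[int]] = [[1, 2], [3]]

# list.copy() creates a new outer list, but the inner lists are shared
shallow = nested.copy()
shallow[0].append(99)
print(nested)  # [[1, 2, 99], [3]]

# copy.deepcopy() also copies the inner lists
deep = copy.deepcopy(nested)
deep[0].append(100)
print(nested)  # [[1, 2, 99], [3]]
print(deep)    # [[1, 2, 99, 100], [3]]
```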

5 — Back to good programming practices

One of the lectures of the previous session was dedicated to good programming practices.

In Python, Python Enhancement Proposal 8 (PEP 8) is a style guide that you should follow as much as possible. To check how well your code complies with it, you can use Pylint. It returns a score as well as recommendations to improve the readability of your code.

Use Pylint to evaluate and then improve the quality of the code you produced for the first two exercises of this session (List of random numbers and Filter a list).
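As an illustration, here is one possible revision of rand_ints that addresses remarks Pylint typically makes on the correction of exercise 1, such as parameters redefining the built-in names min and max, or a missing module docstring. The renamed parameters (min_val, max_val, nb_values) are our own choice and change how the function is called by keyword; the exact messages and score depend on your Pylint version and configuration.

```python
"""Utility functions to generate lists of random numbers."""

# Needed imports
from random import randint
from typing import List


def rand_ints(min_val: int = 0, max_val: int = 100, nb_values: int = 10) -> List[int]:
    """
    Generate a list of random integers between min_val and max_val (both included).

    In:
        * min_val:   The minimum value for the random integers.
        * max_val:   The maximum value for the random integers.
        * nb_values: The number of random integers to generate.
    Out:
        * A list of random integers.
    """

    # "_" signals that the loop variable is intentionally unused
    return [randint(min_val, max_val) for _ in range(nb_values)]


if __name__ == "__main__":
    # Example usage with default and custom parameters
    print(rand_ints())
    print(rand_ints(10, 50, 5))
```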

6 — Static type checking

Type mismatch is one of the errors that are only discovered at runtime. Just by reading the following code, try to identify which runtime error will occur.

def prime_number(n: int) -> bool:

    """
        Check if a number is prime.
        In:
            * n: The number to check.
        Out:
            * True if the number is prime, False otherwise.
    """

    # A prime number is a number greater than 1 that has no divisors other than 1 and itself
    if n < 2:
        return False

    # Check if the number is divisible by any number from 2 to the square root of the number
    for i in range(2, int(n ** 0.5) + 1):
        if n % i == 0:
            return False
    
    # Return True if we reach this point
    return True

# Test the function
num_test = input("Give a number to check if it is prime: ")
print(f"Is {num_test} a prime number? = {prime_number(num_test)}")

Run the code to confirm that it raises an error. Then try the Mypy type checker, which would have identified this error without running the code. A tool that checks type consistency can be very useful in a code base with thousands of lines, don’t you think?
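Running the original code should fail with a TypeError, since input() always returns a str and the comparison n < 2 is not supported between a str and an int; Mypy should report the same problem statically, as an incompatible argument type in the call to prime_number. One possible fix, sketched below with a hypothetical helper prime_from_text, is to convert the text to an int before the call. The string literals stand in for what input() would return, so that the sketch runs without user interaction.

```python
def prime_number(n: int) -> bool:
    """
        Check if a number is prime.
        In:
            * n: The number to check.
        Out:
            * True if the number is prime, False otherwise.
    """

    # A prime number is greater than 1 and has no divisors other than 1 and itself
    if n < 2:
        return False

    # Check divisibility by every number from 2 to the square root of n
    for i in range(2, int(n ** 0.5) + 1):
        if n % i == 0:
            return False

    # No divisor found
    return True


def prime_from_text(text: str) -> bool:
    """
        Convert a user-provided string to an int, then test it (hypothetical helper).
        In:
            * text: The string typed by the user, e.g. the result of input().
        Out:
            * True if the number is prime, False otherwise.
    """

    # input() always returns a str, so it must be converted explicitly
    return prime_number(int(text))


# In the original script, the argument would be the result of input(...)
print(prime_from_text("29"))  # True
print(prime_from_text("15"))  # False
```

Note that int() raises a ValueError if the text is not a valid integer, which a more robust version could handle with try/except.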

7 — Optimize your solutions

You can now use AI tools such as GitHub Copilot or ChatGPT, either to generate a solution or to improve the first solution you came up with! Try this for all the exercises above, and compare the results with your own solutions.

Activity contents (part 2)

8 — Building a linear regressor

The purpose of this exercise is to guide you step by step through the implementation of a linear regressor. It emphasizes the need to organize your code into small functions.

So, what is a linear regressor? Most of you should already be familiar with this machine learning technique. If not, here is a very short introductory video:
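In short, and in its most common formulation (the notation below is ours, not necessarily the one used later in this activity), a linear regressor predicts y from x with a straight line of slope w and intercept b, chosen to minimize the mean squared error over the n training points. For a single input variable, the minimizing values have a well-known closed form, where x̄ and ȳ are the means of the x_i and of the y_i:

```latex
\begin{align*}
  \hat{y} &= w\,x + b \\
  \mathcal{L}(w, b) &= \frac{1}{n} \sum_{i=1}^{n} \bigl(y_i - (w\,x_i + b)\bigr)^2 \\
  w^{*} &= \frac{\sum_{i=1}^{n} (x_i - \bar{x})(y_i - \bar{y})}{\sum_{i=1}^{n} (x_i - \bar{x})^2},
  \qquad b^{*} = \bar{y} - w^{*}\,\bar{x}
\end{align*}
```

With points produced by the generate_data function given in 8.1, each y_i equals slope · x_i plus a noise term drawn uniformly in [0, noise), whose mean is noise/2; a good fit should therefore give w close to slope and b close to noise/2.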

8.1 — Preparation

The starting point is the following code, which provides basic functions to randomly generate some points and display them. Copy-paste this code into a file named data_manipulation.py.

# Needed imports
from typing import List, Tuple, Optional
import numpy as np
import matplotlib.pyplot as plt



def generate_data ( nb_points:   int = 10,
                    slope:       float = 0.4,
                    noise:       float = 0.2,
                    min_val:     float = -1.0,
                    max_val:     float = 1.0,
                    random_seed: Optional[int] = None
                  ) ->           Tuple[List[float], List[float]]:

    """
        Generate linearly distributed 2D data with added noise.
        This function generates a set of data points (x, y) where x is uniformly distributed between the given minimum and maximum values.
        Value y is calculated as a linear function of x with the given slope, plus a random noise drawn uniformly in [0, noise).
        In:
            * nb_points:   The number of data points to generate.
            * slope:       The slope of the linear function used to generate y values.
            * noise:       The range within which random noise is added to the y values.
            * min_val:     The minimum value of the x coordinates.
            * max_val:     The maximum value of the x coordinates.
            * random_seed: The random seed used to generate the points.
        Out:
            * The x coordinates in a first list.
            * The y coordinates in a second list.
    """

    # Set the random seed
    if random_seed is not None:
        np.random.seed(random_seed)

    # Generate the data
    xrand = np.random.uniform(min_val, max_val, size=(nb_points,))
    delta = np.random.uniform(0, noise, size=(nb_points,))
    ymod = slope * xrand + delta
    return list(xrand), list(ymod)



def scatter_data ( xvals: List[float],
                   yvals: List[float]
                 ) ->     None:
    
    """
        Plot the data in 2D space.
        In:
            * xvals: The x-coordinates of the data points.
            * yvals: The y-coordinates of the data points.
        Out:
            * None.
    """

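    # Set a margin (10% of the data range on each side) for a nice plot
    margin = 0.1
    x_pad = (max(xvals) - min(xvals)) * margin
    y_pad = (max(yvals) - min(yvals)) * margin

    # Plot the data
    axis = plt.gca()
    axis.set_xlim((min(xvals) - x_pad, max(xvals) + x_pad))
    axis.set_ylim((min(yvals) - y_pad, max(yvals) + y_pad))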
    plt.scatter(xvals, yvals, color = "firebrick")
// Needed imports
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.knowm.xchart.SwingWrapper;
import org.knowm.xchart.XYChart;
import org.knowm.xchart.XYChartBuilder;

/**
 * This class should appear in a file named "DataManipulation.java".
 */
public class DataManipulation {

    /**
     * Generate linearly distributed 2D data with added noise. This method generates a set of data points (x, y) where x is uniformly distributed between the given minimum and maximum values.
     * Value y is calculated as a linear function of x with the given slope, plus a random noise drawn uniformly in [0, noise).
     *
     * @param nbPoints   The number of data points to generate.
     * @param slope      The slope of the linear function used to generate y values.
     * @param noise      The range within which random noise is added to the y values.
     * @param minVal     The minimum value of the x coordinates.
     * @param maxVal     The maximum value of the x coordinates.
     * @param randomSeed The random seed used to generate the points.
     * @return           A pair of lists: The first list contains the x coordinates, and the second list contains the y coordinates.
     */
    public static List<List<Double>> generateData(int nbPoints,
                                                  double slope,
                                                  double noise,
                                                  double minVal,
                                                  double maxVal,
                                                  Integer randomSeed) {
        // Set the random seed
        Random rand = randomSeed != null ? new Random(randomSeed) : new Random();

        // Generate the data
        List<Double> xVals = new ArrayList<>();
        List<Double> yVals = new ArrayList<>();
        for (int i = 0; i < nbPoints; i++) {
            double x = minVal + rand.nextDouble() * (maxVal - minVal);
            double delta = rand.nextDouble() * noise;
            double y = slope * x + delta;
            xVals.add(x);
            yVals.add(y);
        }
        List<List<Double>> result = new ArrayList<>();
        result.add(xVals);
        result.add(yVals);
        return result;
    }

    /**
     * Plot the data in 2D space.
     *
     * @param xVals The x-coordinates of the data points.
     * @param yVals The y-coordinates of the data points.
     */
    public void scatterData(List<Double> xVals, List<Double> yVals) {
        // Plot the data
        XYChart chart =
                new XYChartBuilder()
                        .width(800)
                        .height(600)
                        .title("Scatter Plot")
                        .xAxisTitle("X")
                        .yAxisTitle("Y")
                        .build();
        chart.addSeries("Data Points", xVals, yVals);
        new SwingWrapper<>(chart).displayChart();
    }

}
Information

To run this code you will probably need to install the matplotlib library. Check the course on installing modules for a reminder.

Now, update the file above to add some simple tests that verify the functions work as expected. Since we are going to use this file as a module, don't forget to put these tests in the if __name__ == "__main__": block. That way they run only when the file is executed directly, and not when it is imported.
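To see why the guard matters, the sketch below runs the same file twice with the standard runpy module: once without a run name, which behaves like an import, and once as a script. The module itself is a throwaway file named demo_module.py, written to a temporary directory purely for illustration. The guarded block runs only in the second case.

```python
import runpy
import tempfile
from pathlib import Path
from typing import Tuple

# Source of a hypothetical throwaway module, used only for this demonstration
MODULE_SOURCE = '''
tests_ran = False

def double(value):
    return 2 * value

if __name__ == "__main__":
    # Tests placed here only run when the file is executed directly
    assert double(2) == 4
    tests_ran = True
'''


def run_demo() -> Tuple[bool, bool]:
    """Run the module code once like an import and once as a script."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = Path(tmp_dir) / "demo_module.py"
        path.write_text(MODULE_SOURCE)
        # Without run_name, __name__ is "<run_path>" (not "__main__"), so the guard is skipped
        as_module = runpy.run_path(str(path))
        # With run_name="__main__", the file behaves like `python demo_module.py`
        as_script = runpy.run_path(str(path), run_name="__main__")
    return as_module["tests_ran"], as_script["tests_ran"]


print(run_demo())  # (False, True)
```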

Correction
if __name__ == "__main__":

    # Test the generate_data and scatter_data functions
    x, y = generate_data(20, 0.5, 0.2, -1, 1, 42)
    scatter_data(x, y)
    
    # Display the plot
    plt.show()
// Needed imports
import java.util.List;

/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /**
     * This is the entry point of your program.
     * It contains the first codes that are going to be executed.
     *
     * @param args Command line arguments received.
     */
    public static void main(String[] args) {
        // Test the generateData and scatterData functions
        DataManipulation dm = new DataManipulation();
        List<List<Double>> data = dm.generateData(10, 0.4, 0.2, -1.0, 1.0, 42);
        dm.scatterData(data.get(0), data.get(1));
    }

}

When you run the file data_manipulation.py, you should see a scatter plot of the generated points.

The aim now is to implement, step by step, a simple but fundamental data analysis method: linear regression. This exercise focuses on your code rather than on the formulas used to perform linear regression, which you will certainly study in more detail later.

As shown in the following figure, the goal is to find the line $y = ax+b$ that best fits the distribution of the points.

Example of a linear regression

8.2 — Quantifying the error induced by the estimated function

The strategy is to infer a line that minimizes an overall distance to a set of training points, measured vertically (along the y axis). It is therefore necessary to define a function that computes this overall distance. A commonly used error measure is the Residual Sum of Squares Error (RSSE).

Given a set of $n$ points $\mathcal{X} = \{ (x_1, y_1), \ldots, (x_n, y_n) \}$, together with the slope $a$ and the intercept $b$ of the estimated line, the RSSE is computed as follows. For the sake of simplicity and without loss of generality, two-dimensional points are used in this exercise.

$$RSSE(\mathcal{X}, a, b) = \sum_{i=1}^n (y_i -(a x_i + b))^2$$

In a new file named linear_regression.py, implement a function with the following prototype that returns the RSSE.

def rsse ( a: float,
           b: float,
           x: List[float],
           y: List[float]
         ) -> float:

    """
        Compute the RSSE of the line defined by a and b according to the data x and y.
        In:
            * a: Slope.
            * b: Intercept.
            * x: x values of the points.
            * y: y values of the points.
        Out:
            * The computed RSSE.
    """

The line $y = 0.39x + 0.13$ fits the following points with an RSSE $\approx 0.019$: $$\{ (0.11, 0.08), (-0.6, -0.02), (0.7, 0.4), (-0.12, 0.03), (-0.82, -0.2), (-0.36, 0.01) \}$$

Check that for this example your function returns the expected RSSE.
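If you want an independent check of your loop-based version, you can write the same sum more compactly by iterating over the pairs $(x_i, y_i)$ with zip. This is only a sketch. The name rsse_zip is made up here so that it does not clash with the function you are asked to write.

```python
from typing import List


def rsse_zip(a: float, b: float, x: List[float], y: List[float]) -> float:
    """RSSE of the line y = a * x + b, written with zip and a generator."""
    # zip pairs each x_i with its y_i; note that it stops at the shorter list
    return sum((yi - (a * xi + b)) ** 2 for xi, yi in zip(x, y))


# Example points and line from the exercise statement
x = [0.11, -0.6, 0.7, -0.12, -0.82, -0.36]
y = [0.08, -0.02, 0.4, 0.03, -0.2, 0.01]
print(round(rsse_zip(0.39, 0.13, x, y), 3))  # 0.019
```

Because zip silently stops at the shorter list, check that x and y have the same length. From Python 3.10 on, you can also pass strict=True to zip to raise an error on a length mismatch.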

Correction
# Needed imports
from typing import List



def rsse ( a: float,
           b: float,
           x: List[float],
           y: List[float]
         ) -> float:

    """
        Compute the Residual Sum of Squares Error of the line defined by a and b according to the data x and y.
        In:
            * a: Slope.
            * b: Intercept.
            * x: x values of the points.
            * y: y values of the points.
        Out:
            * The computed RSSE.
    """

    # Compute the RSSE
    total = 0.0
    for i in range(len(x)):
        total += (y[i] - (a * x[i] + b)) ** 2
    return total



if __name__ == "__main__":

    # Test the function
    x = [0.11, -0.6, 0.7, -0.12, -0.82, -0.36]
    y = [0.08, -0.02, 0.4, 0.03, -0.2, 0.01]
    a = 0.39
    b = 0.13
    print("RSSE: ", rsse(a, b, x, y))
// Needed imports
import java.util.Arrays;
import java.util.List;

/**
 * This class should appear in a file named "LinearRegression.java".
 */
public class LinearRegression {

    /**
     * Compute the Residual Sum of Squares Error (RSSE) of the line defined by a and b according to the data x and y.
     *
     * @param a Slope.
     * @param b Intercept.
     * @param x x values of the points.
     * @param y y values of the points.
     * @return  The computed RSSE.
     */
    public double rsse(double a, double b, List<Double> x, List<Double> y) {
        // Compute the RSSE
        double rsse = 0.0;
        for (int i = 0; i < x.size(); i++) {
            rsse += Math.pow((y.get(i) - (a * x.get(i) + b)), 2);
        }
        return rsse;
    }

}

/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /**
     * This is the entry point of your program.
     * It contains the first codes that are going to be executed.
     *
     * @param args Command line arguments received.
     */
    public static void main(String[] args) {
        // Test the RSSE function
        LinearRegression lr = new LinearRegression();
        List<Double> x = Arrays.asList(0.11, -0.6, 0.7, -0.12, -0.82, -0.36);
        List<Double> y = Arrays.asList(0.08, -0.02, 0.4, 0.03, -0.2, 0.01);
        double a = 0.39;
        double b = 0.13;
        System.out.println("RSSE: " + lr.rsse(a, b, x, y));
    }

}

8.3 — The gradient of the error function

To minimize the error, we now need to determine how to adjust the slope and intercept of the current line. The adjustment is made in the direction opposite to the gradient, because the gradient points toward the steepest increase of the error. Since the error function has two variables, its gradient has two components. They are computed as follows, where $n$ is the number of points:

$$\nabla RSSE(a,b) = \left(\sum_{i=1}^n -2 x_i(y_i - (ax_i + b)), \sum_{i=1}^n -2(y_i - (ax_i + b))\right)$$
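A convenient way to double-check a gradient formula is to compare it with central finite differences, $\frac{f(p+h) - f(p-h)}{2h}$, applied to each parameter in turn. The sketch below is self-contained, and its helper names are made up for this illustration. Because the RSSE is quadratic in $a$ and $b$, the two gradients should agree up to rounding errors.

```python
from typing import List, Tuple


def rsse(a: float, b: float, x: List[float], y: List[float]) -> float:
    # Residual sum of squares of the line y = a * x + b
    return sum((yi - (a * xi + b)) ** 2 for xi, yi in zip(x, y))


def analytic_gradient(a: float, b: float, x: List[float], y: List[float]) -> Tuple[float, float]:
    # Direct transcription of the two sums in the gradient formula
    grad_a = sum(-2 * xi * (yi - (a * xi + b)) for xi, yi in zip(x, y))
    grad_b = sum(-2 * (yi - (a * xi + b)) for xi, yi in zip(x, y))
    return grad_a, grad_b


def numerical_gradient(a: float, b: float, x: List[float], y: List[float],
                       h: float = 1e-5) -> Tuple[float, float]:
    # Central differences: perturb one parameter at a time by +h and -h
    grad_a = (rsse(a + h, b, x, y) - rsse(a - h, b, x, y)) / (2 * h)
    grad_b = (rsse(a, b + h, x, y) - rsse(a, b - h, x, y)) / (2 * h)
    return grad_a, grad_b


x = [0.11, -0.6, 0.7, -0.12, -0.82, -0.36]
y = [0.08, -0.02, 0.4, 0.03, -0.2, 0.01]
print(analytic_gradient(0.4, 0.2, x, y))
print(numerical_gradient(0.4, 0.2, x, y))
```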

In your linear_regression.py file, implement a function with the following prototype that returns the two components of the gradient. A single loop over the points is enough to accumulate both components. You can initialize them either to 0 or to the contribution of the first point.

def gradient_rsse ( slope:     float,
                    intercept: float,
                    xvals:     List[float],
                    yvals :    List[float]
                  ) ->         Tuple[float, float]:

    """
        Compute the gradient of the RSSE.
        In:
            * slope:     The slope of the current function.
            * intercept: The intercept of the current function.
            * xvals:     x values of the points to fit.
            * yvals:     y values of the points to fit.
        Out:
            * The gradient of the RSSE.
    """
Correction
# Needed imports
from typing import List, Tuple



def rsse ( a: float,
           b: float,
           x: List[float],
           y: List[float]
         ) -> float:

    """
        Compute the Residual Sum of Squares Error of the line defined by a and b according to the data x and y.
        In:
            * a: Slope.
            * b: Intercept.
            * x: x values of the points.
            * y: y values of the points.
        Out:
            * The computed RSSE.
    """

    # Compute the RSSE
    total = 0.0
    for i in range(len(x)):
        total += (y[i] - (a * x[i] + b)) ** 2
    return total



def gradient_rsse ( slope:     float,
                    intercept: float,
                    xvals:     List[float],
                    yvals :    List[float]
                  ) ->         Tuple[float, float]:

    """
        Compute the gradient of the RSSE.
        In:
            * slope:     The slope of the current function.
            * intercept: The intercept of the current function.
            * xvals:     x values of the points to fit.
            * yvals:     y values of the points to fit.
        Out:
            * The gradient of the RSSE.
    """

    # Compute the gradient
    grad_a = 0.0
    grad_b = 0.0
    for i in range(len(xvals)):
        grad_a += -2 * xvals[i] * (yvals[i] - (slope * xvals[i] + intercept))
        grad_b += -2 * (yvals[i] - (slope * xvals[i] + intercept))
    return grad_a, grad_b



if __name__ == "__main__":

    # Test the RSSE function
    x = [0.11, -0.6, 0.7, -0.12, -0.82, -0.36]
    y = [0.08, -0.02, 0.4, 0.03, -0.2, 0.01]
    a = 0.39
    b = 0.13
    print("RSSE: ", rsse(a, b, x, y))

    # Test the gradient_rsse function
    slope = 0.4
    intercept = 0.2
    print("Gradient: ", gradient_rsse(slope, intercept, x, y))
// Needed imports
import java.util.List;
import java.util.Arrays;

/**
 * This class should appear in a file named "LinearRegression.java".
 */
public class LinearRegression {

    /**
     * Compute the Residual Sum of Squares Error (RSSE) of the line defined by a and b according to the data x and y.
     *
     * @param a Slope.
     * @param b Intercept.
     * @param x x values of the points.
     * @param y y values of the points.
     * @return  The computed RSSE.
     */
    public double rsse(double a, double b, List<Double> x, List<Double> y) {
        // Compute the RSSE
        double rsse = 0.0;
        for (int i = 0; i < x.size(); i++) {
            rsse += Math.pow((y.get(i) - (a * x.get(i) + b)), 2);
        }
        return rsse;
    }

    /**
     * Compute the gradient of the RSSE.
     *
     * @param slope     The slope of the current function.
     * @param intercept The intercept of the current function.
     * @param xvals     x values of the points to fit.
     * @param yvals     y values of the points to fit.
     * @return          A pair of gradients (grad_a, grad_b).
     */
    public double[] gradientRSSE(double slope, double intercept, List<Double> xvals, List<Double> yvals) {
        // Compute the gradient
        double gradA = 0.0;
        double gradB = 0.0;
        for (int i = 0; i < xvals.size(); i++) {
            gradA += -2 * xvals.get(i) * (yvals.get(i) - (slope * xvals.get(i) + intercept));
            gradB += -2 * (yvals.get(i) - (slope * xvals.get(i) + intercept));
        }
        return new double[] {gradA, gradB};
    }

}

/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /**
     * This is the entry point of your program.
     * It contains the first codes that are going to be executed.
     *
     * @param args Command line arguments received.
     */
    public static void main(String[] args) {
        // Test the RSSE function
        LinearRegression lr = new LinearRegression();
        List<Double> x = Arrays.asList(0.11, -0.6, 0.7, -0.12, -0.82, -0.36);
        List<Double> y = Arrays.asList(0.08, -0.02, 0.4, 0.03, -0.2, 0.01);
        double a = 0.39;
        double b = 0.13;
        System.out.println("RSSE: " + lr.rsse(a, b, x, y));

        // Test the gradientRSSE function
        double slope = 0.4;
        double intercept = 0.2;
        double[] gradients = lr.gradientRSSE(slope, intercept, x, y);
        System.out.println("Gradient: grad_a = " + gradients[0] + ", grad_b = " + gradients[1]);
    }

}

8.4 — The gradient descent

The last step of the approach is to use the gradient components to adjust the slope and intercept of the line. To do so, we will use an algorithm called gradient descent, which iteratively updates the parameters in order to minimize a function (here, the RSSE). This process needs two hyper-parameters: the number of rounds of adjustment, called "epochs", and the "learning rate", which scales the size of the adjustment made at each epoch.
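Written out, one epoch applies the update below. Here $\eta$ (a symbol introduced for this illustration) denotes the learning rate, and the two sums are the components of $\nabla RSSE$ evaluated at the current $(a, b)$:

$$a \leftarrow a - \eta \sum_{i=1}^n -2 x_i(y_i - (ax_i + b)), \qquad b \leftarrow b - \eta \sum_{i=1}^n -2(y_i - (ax_i + b))$$

Both components are computed before either parameter is updated. As a result, the new $b$ is based on the gradient at the old $a$, not at the freshly updated one.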

Define a function called gradient_descent that takes as parameters:

  • the x_values of the points.
  • the y_values of the points.
  • the initial_slope whose default value is 0.
  • the initial_intercept whose default value is 0.
  • the learning_rate whose default value is 0.02.
  • the epochs whose default value is 100.

The function returns the slope and intercept obtained after the given number of epochs.
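The choice of learning rate matters. The self-contained sketch below uses descent_history, a hypothetical helper that is not part of the requested function. It records the RSSE after every epoch, starting from a slope and an intercept of 0. On the six example points of section 8.2, a learning rate of 0.02 makes the error decrease. A much larger rate such as 0.2 makes the updates overshoot, so the error grows instead.

```python
from typing import List


def rsse(a: float, b: float, x: List[float], y: List[float]) -> float:
    # Residual sum of squares of the line y = a * x + b
    return sum((yi - (a * xi + b)) ** 2 for xi, yi in zip(x, y))


def descent_history(x: List[float], y: List[float],
                    learning_rate: float, epochs: int) -> List[float]:
    """Run gradient descent from (0, 0) and return the RSSE before and after each epoch."""
    slope, intercept = 0.0, 0.0
    history = [rsse(slope, intercept, x, y)]
    for _ in range(epochs):
        # Both gradient components are evaluated at the current parameters
        grad_a = sum(-2 * xi * (yi - (slope * xi + intercept)) for xi, yi in zip(x, y))
        grad_b = sum(-2 * (yi - (slope * xi + intercept)) for xi, yi in zip(x, y))
        slope -= learning_rate * grad_a
        intercept -= learning_rate * grad_b
        history.append(rsse(slope, intercept, x, y))
    return history


x = [0.11, -0.6, 0.7, -0.12, -0.82, -0.36]
y = [0.08, -0.02, 0.4, 0.03, -0.2, 0.01]
small_steps = descent_history(x, y, learning_rate=0.02, epochs=100)
large_steps = descent_history(x, y, learning_rate=0.2, epochs=30)
print("lr=0.02:", small_steps[0], "->", small_steps[-1])
print("lr=0.2: ", large_steps[0], "->", large_steps[-1])
```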

Correction
# Needed imports
from typing import List, Tuple



def rsse ( a: float,
           b: float,
           x: List[float],
           y: List[float]
         ) -> float:

    """
        Compute the Residual Sum of Squared Errors of the line defined by a and b according to the data x and y.
        In:
            * a: Slope.
            * b: Intercept.
            * x: x values of the points.
            * y: y values of the points.
        Out:
            * The computed RSSE.
    """

    # Compute the RSSE
    total = 0.0
    for i in range(len(x)):
        total += (y[i] - (a * x[i] + b)) ** 2
    return total



def gradient_rsse ( slope:     float,
                    intercept: float,
                    xvals:     List[float],
                    yvals :    List[float]
                  ) ->         Tuple[float, float]:

    """
        Compute the gradient of the RSSE.
        In:
            * slope:     The slope of the current function.
            * intercept: The intercept of the current function.
            * xvals:     x values of the points to fit.
            * yvals:     y values of the points to fit.
        Out:
            * The gradient of the RSSE.
    """

    # Compute the gradient
    grad_a = 0.0
    grad_b = 0.0
    for i in range(len(xvals)):
        grad_a += -2 * xvals[i] * (yvals[i] - (slope * xvals[i] + intercept))
        grad_b += -2 * (yvals[i] - (slope * xvals[i] + intercept))
    return grad_a, grad_b



def gradient_descent ( x_values:          List[float],
                       y_values:          List[float],
                       initial_slope:     float = 0,
                       initial_intercept: float = 0,
                       learning_rate:     float = 0.02,
                       epochs:            int = 100
                     ) ->                 Tuple[float, float]:

    """
        Perform a gradient descent to fit a line to the data.
        In:
            * x_values:          The x values of the data points.
            * y_values:          The y values of the data points.
            * initial_slope:     The initial slope of the line.
            * initial_intercept: The initial intercept of the line.
            * learning_rate:     The learning rate of the gradient descent.
            * epochs:            The number of epochs to perform.
        Out:
            * The slope and intercept of the fitted line.
    """

    # Initialize the slope and intercept
    slope = initial_slope
    intercept = initial_intercept

    # Perform the gradient descent
    for _ in range(epochs):

        # Compute the gradient
        grad_a, grad_b = gradient_rsse(slope, intercept, x_values, y_values)

        # Update the slope and intercept
        slope -= learning_rate * grad_a
        intercept -= learning_rate * grad_b

    return slope, intercept



if __name__ == "__main__":

    # Test the RSSE function
    x = [0.11, -0.6, 0.7, -0.12, -0.82, -0.36]
    y = [0.08, -0.02, 0.4, 0.03, -0.2, 0.01]
    a = 0.39
    b = 0.13
    print("RSSE: ", rsse(a, b, x, y))

    # Test the gradient_rsse function
    slope = 0.4
    intercept = 0.2
    print("Gradient: ", gradient_rsse(slope, intercept, x, y))

    # Test the gradient_descent function
    slope, intercept = gradient_descent(x, y)
    print("Slope: ", slope)
    print("Intercept: ", intercept)
// Needed imports
import java.util.List;
import java.util.Arrays;

/**
 * This class should appear in a file named "LinearRegression.java".
 */
public class LinearRegression {

    /**
     * Compute the Residual Sum of Squares Error (RSSE) of the line defined by a and b according to the data x and y.
     *
     * @param a Slope.
     * @param b Intercept.
     * @param x x values of the points.
     * @param y y values of the points.
     * @return  The computed RSSE.
     */
    public double rsse(double a, double b, List<Double> x, List<Double> y) {
        // Compute the RSSE
        double rsse = 0.0;
        for (int i = 0; i < x.size(); i++) {
            rsse += Math.pow((y.get(i) - (a * x.get(i) + b)), 2);
        }
        return rsse;
    }

    /**
     * Compute the gradient of the RSSE.
     *
     * @param slope     The slope of the current function.
     * @param intercept The intercept of the current function.
     * @param xvals     x values of the points to fit.
     * @param yvals     y values of the points to fit.
     * @return          A pair of gradients (grad_a, grad_b).
     */
    public double[] gradientRSSE(double slope, double intercept, List<Double> xvals, List<Double> yvals) {
        // Compute the gradient
        double gradA = 0.0;
        double gradB = 0.0;
        for (int i = 0; i < xvals.size(); i++) {
            gradA += -2 * xvals.get(i) * (yvals.get(i) - (slope * xvals.get(i) + intercept));
            gradB += -2 * (yvals.get(i) - (slope * xvals.get(i) + intercept));
        }
        return new double[] {gradA, gradB};
    }

    /**
     * Perform a gradient descent to fit a line to the data.
     *
     * @param xValues          The x values of the data points.
     * @param yValues          The y values of the data points.
     * @param initialSlope     The initial slope of the line.
     * @param initialIntercept The initial intercept of the line.
     * @param learningRate     The learning rate of the gradient descent.
     * @param epochs           The number of epochs to perform.
     * @return                 The slope and intercept of the fitted line.
     */
    public double[] gradientDescent(
            List<Double> xValues,
            List<Double> yValues,
            double initialSlope,
            double initialIntercept,
            double learningRate,
            int epochs) {
        // Initialize the slope and intercept
        double slope = initialSlope;
        double intercept = initialIntercept;

        // Perform the gradient descent
        for (int i = 0; i < epochs; i++) {
            // Compute the gradient
            double[] gradients = gradientRSSE(slope, intercept, xValues, yValues);
            double gradA = gradients[0];
            double gradB = gradients[1];

            // Update the slope and intercept
            slope -= learningRate * gradA;
            intercept -= learningRate * gradB;
        }
        return new double[] {slope, intercept};
    }

}

/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /**
     * This is the entry point of your program.
     * It contains the first codes that are going to be executed.
     *
     * @param args Command line arguments received.
     */
    public static void main(String[] args) {
        // Test the RSSE function
        LinearRegression lr = new LinearRegression();
        List<Double> x = Arrays.asList(0.11, -0.6, 0.7, -0.12, -0.82, -0.36);
        List<Double> y = Arrays.asList(0.08, -0.02, 0.4, 0.03, -0.2, 0.01);
        double a = 0.39;
        double b = 0.13;
        System.out.println("RSSE: " + lr.rsse(a, b, x, y));

        // Test the gradientRSSE function
        double slope = 0.4;
        double intercept = 0.2;
        double[] gradients = lr.gradientRSSE(slope, intercept, x, y);
        System.out.println("Gradient: grad_a = " + gradients[0] + ", grad_b = " + gradients[1]);

        // Test the gradientDescent function
        double[] result = lr.gradientDescent(x, y, 0, 0, 0.02, 100);
        System.out.println("Fitted line: slope = " + result[0] + ", intercept = " + result[1]);
    }

}
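As a sanity check that goes beyond what this exercise asks, you can compare the result of gradient descent with the closed-form least-squares solution. That solution minimizes the same RSSE exactly: $a = \frac{\sum_i (x_i - \bar{x})(y_i - \bar{y})}{\sum_i (x_i - \bar{x})^2}$ and $b = \bar{y} - a\bar{x}$, where $\bar{x}$ and $\bar{y}$ are the means of the coordinates. In the sketch below, gradient_descent applies the same update as the correction above, and least_squares is a hypothetical helper. With enough epochs, the two results should agree closely.

```python
from typing import List, Tuple


def gradient_descent(x: List[float], y: List[float],
                     slope: float = 0.0, intercept: float = 0.0,
                     learning_rate: float = 0.02, epochs: int = 100) -> Tuple[float, float]:
    """Same update rule as in the correction: step against the RSSE gradient."""
    for _ in range(epochs):
        grad_a = sum(-2 * xi * (yi - (slope * xi + intercept)) for xi, yi in zip(x, y))
        grad_b = sum(-2 * (yi - (slope * xi + intercept)) for xi, yi in zip(x, y))
        slope -= learning_rate * grad_a
        intercept -= learning_rate * grad_b
    return slope, intercept


def least_squares(x: List[float], y: List[float]) -> Tuple[float, float]:
    """Closed-form slope and intercept minimizing the RSSE (hypothetical helper)."""
    n = len(x)
    x_mean = sum(x) / n
    y_mean = sum(y) / n
    # Sums of centered products and centered squares
    s_xy = sum((xi - x_mean) * (yi - y_mean) for xi, yi in zip(x, y))
    s_xx = sum((xi - x_mean) ** 2 for xi in x)
    slope = s_xy / s_xx
    return slope, y_mean - slope * x_mean


x = [0.11, -0.6, 0.7, -0.12, -0.82, -0.36]
y = [0.08, -0.02, 0.4, 0.03, -0.2, 0.01]
print("Gradient descent:", gradient_descent(x, y, epochs=1000))
print("Closed form:     ", least_squares(x, y))
```

With the default 100 epochs, the gradient descent result is close to the closed-form one but not yet identical, which is a good reason to experiment with the epochs parameter.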

Update your data_manipulation.py file with a new function that draws the obtained line. To draw on an existing figure, use plt.gca() to access its current axes.

Correction
# Needed imports
from typing import List, Tuple, Optional
import numpy as np
import matplotlib.pyplot as plt



def generate_data ( nb_points:   int = 10,
                    slope:       float = 0.4,
                    noise:       float = 0.2,
                    min_val:     float = -1.0,
                    max_val:     float = 1.0,
                    random_seed: Optional[int] = None
                  ) ->           Tuple[List[float], List[float]]:

    """
        Generate linearly distributed 2D data with added noise.
        This function generates a set of data points (x, y) where x is uniformly distributed between predefined minimum and maximum values.
        Value y is calculated as a linear function of x with a specified inclination and an added random noise within a specified range.
        In:
            * nb_points:   The number of data points to generate.
            * slope:       The slope of the linear function used to generate y values.
            * noise:       The range within which random noise is added to the y values.
            * min_val:     The minimum value of the x coordinates.
            * max_val:     The maximum value of the x coordinates.
            * random_seed: The random seed used to generate the points.
        Out:
            * The x coordinates in a first list.
            * The y coordinates in a second list.
    """

    # Set the random seed
    if random_seed is not None:
        np.random.seed(random_seed)

    # Generate the data
    xrand = np.random.uniform(min_val, max_val, size=(nb_points,))
    delta = np.random.uniform(0, noise, size=(nb_points,))
    ymod = slope * xrand + delta
    return list(xrand), list(ymod)



def scatter_data ( xvals: List[float],
                   yvals: List[float]
                 ) ->     None:
    
    """
        Plot the data in 2D space.
        In:
            * xvals: The x-coordinates of the data points.
            * yvals: The y-coordinates of the data points.
        Out:
            * None.
    """

    # Set a margin for a nice plot
    margin = 1.1

    # Plot the data
    axis = plt.gca()
    axis.set_xlim((min(xvals) * margin, max(xvals) * margin))
    axis.set_ylim((min(yvals) * margin, max(yvals) * margin))
    plt.scatter(xvals, yvals, color = "firebrick")



def plot_line (a: float, b: float) -> None:
    
    """
        Plot a line in the current plot.
        In:
            * a: The slope of the line.
            * b: The intercept of the line.
        Out:
            * None.
    """

    # Plot the line
    ax = plt.gca()
    ax.axline((0, b), slope=a, color='C0')



if __name__ == "__main__":

    # Test the generate_data and scatter_data functions
    x, y = generate_data(20, 0.5, 0.2, -1, 1, 42)
    scatter_data(x, y)
    
    # Test the plot_line function
    plot_line(-0.5, 0.0)

    # Display the plot
    plt.show()
// Needed imports
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.knowm.xchart.SwingWrapper;
import org.knowm.xchart.XYChart;
import org.knowm.xchart.XYChartBuilder;

/**
 * This class should appear in a file named "DataManipulation.java".
 */
public class DataManipulation {

    /**
     * Generate linearly distributed 2D data with added noise. This method generates a set of data points (x, y) where x is uniformly distributed between predefined minimum and maximum values.
     * Value y is calculated as a linear function of x with a specified inclination and an added random noise within a specified range.
     *
     * @param nbPoints   The number of data points to generate.
     * @param slope      The slope of the linear function used to generate y values.
     * @param noise      The range within which random noise is added to the y values.
     * @param minVal     The minimum value of the x coordinates.
     * @param maxVal     The maximum value of the x coordinates.
     * @param randomSeed The random seed used to generate the points.
     * @return           A pair of lists: The first list contains the x coordinates, and the second list contains the y coordinates.
     */
    public static List<List<Double>> generateData(
            int nbPoints,
            double slope,
            double noise,
            double minVal,
            double maxVal,
            Integer randomSeed) {
        // Set the random seed
        Random rand = randomSeed != null ? new Random(randomSeed) : new Random();

        // Generate the data
        List<Double> xVals = new ArrayList<>();
        List<Double> yVals = new ArrayList<>();
        for (int i = 0; i < nbPoints; i++) {
            double x = minVal + rand.nextDouble() * (maxVal - minVal);
            double delta = rand.nextDouble() * noise;
            double y = slope * x + delta;
            xVals.add(x);
            yVals.add(y);
        }
        List<List<Double>> result = new ArrayList<>();
        result.add(xVals);
        result.add(yVals);
        return result;
    }

    /**
     * Plot the data in 2D space.
     *
     * @param xVals The x-coordinates of the data points.
     * @param yVals The y-coordinates of the data points.
     */
    public XYChart scatterData(List<Double> xVals, List<Double> yVals) {
        // Plot the data
        XYChart chart =
                new XYChartBuilder()
                        .width(800)
                        .height(600)
                        .title("Scatter Plot")
                        .xAxisTitle("X")
                        .yAxisTitle("Y")
                        .build();
        chart.addSeries("Data Points", xVals, yVals);
        return chart;
    }

    /**
     * Plot a line in the given chart.
     *
     * @param a     The slope of the line.
     * @param b     The intercept of the line.
     * @param chart The chart in which the line is drawn.
     */
    public void plotLine(double a, double b, XYChart chart) {
        // Generate the x values (start and end points)
        double xMin = chart.getStyler().getXAxisMin() != null ? chart.getStyler().getXAxisMin() : 0;
        double xMax = chart.getStyler().getXAxisMax() != null ? chart.getStyler().getXAxisMax() : 1;

        // Calculate corresponding y values
        double yMin = a * xMin + b;
        double yMax = a * xMax + b;

        // Add the line to the chart
        chart.addSeries("Line", new double[] {xMin, xMax}, new double[] {yMin, yMax}).setMarker(null);
    }

}

/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /**
     * This is the entry point of your program.
     * It contains the first codes that are going to be executed.
     *
     * @param args Command line arguments received.
     */
    public static void main(String[] args) {
        // Test the generateData and scatterData functions
        DataManipulation dm = new DataManipulation();
        List<List<Double>> data = dm.generateData(10, 0.4, 0.2, -1.0, 1.0, 42);
        var chart = dm.scatterData(data.get(0), data.get(1));

        // Test the plotLine function
        dm.plotLine(-0.5, 0, chart);

        // Display the chart
        new SwingWrapper<>(chart).displayChart();
    }

}

It is now time to package your application. After following the tutorials on code organization and module management, write a script in a file named main.py that imports the necessary functions from data_manipulation.py and linear_regression.py to:

  • Create a list of points.
  • Fit a linear regression to them.
  • Plot the result.

The required modules must be installed in a virtual environment.

Correction
# Needed imports
import matplotlib.pyplot as plt
from data_manipulation import generate_data, scatter_data, plot_line
from linear_regression import gradient_descent

# Generate the data
x, y = generate_data(nb_points=100, slope=0.5, noise=0.2, random_seed=42)

# Scatter the data
scatter_data(x, y)

# Perform the gradient descent
slope, intercept = gradient_descent(x, y, learning_rate=0.001, epochs=100)

# Plot the line
plot_line(slope, intercept)

# Show the plot
plt.show()
// Needed imports
import java.util.List;
import org.knowm.xchart.SwingWrapper;
import org.knowm.xchart.XYChart;

/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /**
     * This is the entry point of your program.
     * It contains the first codes that are going to be executed.
     *
     * @param args Command line arguments received.
     */
    public static void main(String[] args) {
        // Create an instance of DataManipulation
        DataManipulation dm = new DataManipulation();

        // Generate the data
        List<List<Double>> data = DataManipulation.generateData(100, 0.5, 0.2, -1.0, 1.0, 42);
        List<Double> xVals = data.get(0);
        List<Double> yVals = data.get(1);

        // Scatter the data
        XYChart chart = dm.scatterData(xVals, yVals);

        // Perform the gradient descent
        double[] result = GradientDescent.gradientDescent(xVals, yVals, 0, 0, 0.001, 100);
        double slope = result[0];
        double intercept = result[1];
        System.out.println("Slope: " + slope + ", Intercept: " + intercept);

        // Plot the line
        dm.plotLine(slope, intercept, chart);

        // Show the plot
        new SwingWrapper<>(chart).displayChart();
    }

}

We get the following result:

9 — Optimize your solutions

You can now use AI tools such as GitHub Copilot or ChatGPT, either to generate a solution or to improve the one you came up with. Try this for all the exercises above and compare the results with your own solutions.

10 — Using the linear regressor

Let’s continue exercise 8 by introducing a “train/test split”.

The line returned by the linear regression can be used to predict the y value of other points based on their x value.
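Concretely, once gradient descent has returned a slope and an intercept, predicting amounts to evaluating y = slope * x + intercept for each new x. In the sketch below, predict and mean_squared_error are hypothetical helper names; the mean squared error shown here is one common measure and may differ from the RSSE function you defined earlier.

```python
from typing import List


def predict ( slope:     float,
              intercept: float,
              x:         List[float]
            ) ->         List[float]:

    """
        Returns the y values predicted by the line y = slope * x + intercept.
    """

    return [slope * xi + intercept for xi in x]


def mean_squared_error ( slope:     float,
                         intercept: float,
                         x:         List[float],
                         y:         List[float]
                       ) ->         float:

    """
        Returns the average squared difference between observed and predicted y values.
    """

    predictions = predict(slope, intercept, x)
    return sum((yi - pi) ** 2 for yi, pi in zip(y, predictions)) / len(y)


# Usage example with a hand-written line
print(predict(0.5, 1.0, [0.0, 2.0, 4.0]))                   # [1.0, 2.0, 3.0]
print(mean_squared_error(0.5, 1.0, [0.0, 2.0], [1.0, 3.0])) # 0.5
```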

To experiment with and test the predictive power of an inferred line, let us consider two sets of points: the training data, used to infer the line, and the test data, used to evaluate the prediction. Here is a function you can use to naively split a list of points into two lists according to a ratio. Add it to your data_manipulation.py module.

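# Needed imports (if not already present at the top of data_manipulation.py)
from typing import List, Tuple

def split_data_train_test ( x:     List[float],
                            y:     List[float],
                            ratio: float = 0.8
                          ) ->     Tuple[List[float], List[float], List[float], List[float]]:

    """
        Returns the first ratio of the data points for training and the remaining points for testing.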
        In:
            * x:     The x-coordinates of the data points.
            * y:     The y-coordinates of the data points.
            * ratio: The ratio of the data to use for training.
        Out:
            * The x values for training.
            * The y values for training.
            * The x values for testing.
            * The y values for testing.
    """

    # Compute the split index
    split_index = int(len(x) * ratio)

    # Split the data
    x_train = x[:split_index]
    y_train = y[:split_index]
    x_test = x[split_index:]
    y_test = y[split_index:]
    return x_train, y_train, x_test, y_test
// Needed imports
import java.util.ArrayList;
import java.util.List;

/**
 * This class should appear in a file named "DataManipulation.java".
 */
public class DataManipulation {

    /**
     * Split the data into training and testing sets.
     *
     * @param x     The x-coordinates of the data points.
     * @param y     The y-coordinates of the data points.
     * @param ratio The ratio of the data to use for training.
     * @return      A list containing the x and y values for training and testing.
     */
    public List<List<Double>> splitDataTrainTest(List<Double> x, List<Double> y, double ratio) {
        // Compute the split index
        int splitIndex = (int) (x.size() * ratio);

        // Split the data
        List<Double> xTrain = new ArrayList<>(x.subList(0, splitIndex));
        List<Double> yTrain = new ArrayList<>(y.subList(0, splitIndex));
        List<Double> xTest = new ArrayList<>(x.subList(splitIndex, x.size()));
        List<Double> yTest = new ArrayList<>(y.subList(splitIndex, y.size()));

        // Return the split data as a list of lists
        List<List<Double>> result = new ArrayList<>();
        result.add(xTrain);
        result.add(yTrain);
        result.add(xTest);
        result.add(yTest);
        return result;
    }

}
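The split above is naive because it keeps the original order of the points: if your generate_data function happened to produce x values in a particular order (for example, sorted), the test set would only contain the last x values. A common refinement is to shuffle the indices before splitting. The variant below is a sketch with a hypothetical name, split_data_train_test_shuffled; it relies only on the standard random module and keeps each (x, y) pair together.

```python
import random
from typing import List, Optional, Tuple


def split_data_train_test_shuffled ( x:           List[float],
                                     y:           List[float],
                                     ratio:       float = 0.8,
                                     random_seed: Optional[int] = None
                                   ) ->           Tuple[List[float], List[float], List[float], List[float]]:

    """
        Shuffles the points, then returns the first ratio of them for training and the rest for testing.
        The (x, y) pairs are kept together.
    """

    # Shuffle the indices with a dedicated generator, leaving the global random state untouched
    indices = list(range(len(x)))
    random.Random(random_seed).shuffle(indices)

    # Compute the split index
    split_index = int(len(x) * ratio)
    train_indices = indices[:split_index]
    test_indices = indices[split_index:]

    # Build the four lists from the shuffled indices
    x_train = [x[i] for i in train_indices]
    y_train = [y[i] for i in train_indices]
    x_test = [x[i] for i in test_indices]
    y_test = [y[i] for i in test_indices]
    return x_train, y_train, x_test, y_test


# Usage example on 10 points lying on y = 2x
x = [float(i) for i in range(10)]
y = [2.0 * xi for xi in x]
x_train, y_train, x_test, y_test = split_data_train_test_shuffled(x, y, ratio=0.8, random_seed=42)
print(len(x_train), len(x_test))  # 8 2
```

Passing the same random_seed gives the same split, which keeps experiments reproducible.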

Use your RSSE function to evaluate the quality of the prediction, for 1,000 points and a 0.8 split.

Correction
# Needed imports
from data_manipulation import *
from gradient_descent import *

# Generate the data
x, y = generate_data(nb_points=1000, slope=0.5, noise=0.2, random_seed=42)
x_train, y_train, x_test, y_test = split_data_train_test(x, y, ratio=0.8)

# Perform the gradient descent
slope, intercept = gradient_descent(x_train, y_train, learning_rate=0.001, epochs=100)

# Compute average RSSE on training set
average_rsse_train = rsse(slope, intercept, x_train, y_train) / len(x_train)
print("Average RSSE on training set:", average_rsse_train)

# Compute average RSSE on test set
average_rsse_test = rsse(slope, intercept, x_test, y_test) / len(x_test)
print("Average RSSE on test set:", average_rsse_test)
// Needed imports
import java.util.List;

/**
 * To run this code, you need to have Java installed on your computer, then:
 * - Create a file named `Main.java` in a directory of your choice.
 * - Copy this code in the file.
 * - Open a terminal in the directory where the file is located.
 * - Run the command `javac Main.java` to compile the code.
 * - Run the command `java -ea Main` to execute the compiled code.
 * Note: '-ea' is an option to enable assertions in Java.
 */
public class Main {

    /**
     * This is the entry point of your program.
     * It contains the first instructions to be executed.
     *
     * @param args Command line arguments received.
     */
    public static void main(String[] args) {
        // Create an instance of DataManipulation
        DataManipulation dm = new DataManipulation();

        // Generate the data
        List<List<Double>> data = DataManipulation.generateData(1000, 0.5, 0.2, -1.0, 1.0, 42);
        List<Double> xVals = data.get(0);
        List<Double> yVals = data.get(1);

        // Split the data into training and testing sets
        List<List<Double>> splitData = dm.splitDataTrainTest(xVals, yVals, 0.8);
        List<Double> xTrain = splitData.get(0);
        List<Double> yTrain = splitData.get(1);
        List<Double> xTest = splitData.get(2);
        List<Double> yTest = splitData.get(3);

        // Perform the gradient descent on training data
        double[] result = GradientDescent.gradientDescent(xTrain, yTrain, 0, 0, 0.001, 100);
        double slope = result[0];
        double intercept = result[1];
        System.out.println("Slope: " + slope + ", Intercept: " + intercept);

        // Compute average RSSE on training set
        double averageRsseTrain = LinearRegression.rsse(slope, intercept, xTrain, yTrain) / xTrain.size();
        System.out.println("Average RSSE on training set: " + averageRsseTrain);

        // Compute average RSSE on test set
        double averageRsseTest = LinearRegression.rsse(slope, intercept, xTest, yTest) / xTest.size();
        System.out.println("Average RSSE on test set: " + averageRsseTest);
    }

}

To go further

12 — A basic textual Space Invaders

To continue experimenting with functions, let’s move on to a less scientific application by implementing what your father will surely consider to be the best video game of all time: Space Invaders (in textual mode).

The implementation of the game relies on two types of data structures not explicitly seen so far: lists and dictionaries. Short introductions to these two built-in data structures are available here.
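As a quick reminder of these two structures, here is how a list of dictionaries, similar to the one used for the aliens later on, can be created, read and updated. The values are illustrative only.

```python
from typing import Any, Dict, List

# A dictionary maps keys to values: here, one alien
alien: Dict[str, Any] = {"pos_x": 3, "pos_y": 0, "direction": "right"}

# A list keeps several values in order: here, several aliens
aliens: List[Dict[str, Any]] = [alien, {"pos_x": 4, "pos_y": 0, "direction": "right"}]

# Read a value through its index in the list, then its key in the dictionary
print(aliens[1]["pos_x"])  # 4

# Update a value in place: every name referring to this dictionary sees the change
aliens[0]["pos_x"] += 1
print(alien["pos_x"])  # 4

# Add an element at the end of the list, then count the elements
aliens.append({"pos_x": 5, "pos_y": 0, "direction": "right"})
print(len(aliens))  # 3
```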

12.1 — Game display and keyboard interactions

Create a new VSCode workspace that will contain your implementation of the Space Invaders game. Copy the keyboardcapture.py file into your workspace. It provides a function that captures some keys of your keyboard without blocking the rest of the game. The keys used are:

  • ‘k’ to move the cannon to the left.
  • ‘m’ to move the cannon to the right.
  • ‘o’ to shoot.
  • ‘q’ to quit the game.

Then create a file named spaceinvaders.py with the following contents, which simply display the game board and run the main loop.

"""
    A homemade simplified version of the Space Invaders game.
"""

# Needed imports
import os
from typing import List, Dict, Tuple, Any, Optional
import keyboardcapture
from time import sleep



# Constants
WIDTH = 30
HEIGHT = 15
SLEEP_TIME = 0.08
LEFT_KEY = "k"
RIGHT_KEY = "m"
SHOOT_KEY = "o"
QUIT_KEY = "q"



def display_game ( game:   Dict[str, int],
                   cannon: Dict[str, Any],
                   aliens: List[Dict[str, Any]]
                 ) ->      None :


    """
        Function to display the game board.
        In:
            * game:   The game dictionary, containing the parameters of the game board (life, score, level).
            * cannon: The cannon dictionary, containing the parameters of the player's cannon (pos_x, shoot).
            * aliens: The aliens list containing the parameters of the aliens (pos_x, pos_y, direction).
        Out:
            * None.
    """

    # Clear the screen
    os.system("clear")
    
    # Display the top info
    print("+" + "-" * WIDTH + "+", "   Left:  <" + LEFT_KEY + ">")
    print("|" + "SCORE   LIFE   LEVEL".center(WIDTH, " ") + "|", "   Right: <" + RIGHT_KEY + ">")
    print("|" + (str(game["score"]).center(5, " ") + "   " + str(game["life"]).center(4, " ") + "   " + str(game["level"]).center(5, " ")).center(WIDTH, " ") + "|", "   Shoot: <" + SHOOT_KEY + ">")
    print("+" + "-" * WIDTH + "+", "   Quit:  <" + QUIT_KEY + ">")

    # Scan all the cells of the game board to check what has to be displayed
    for i in range(HEIGHT):
        print("|", end="")
        for j in range(WIDTH):

            # Display the cannon where needed
            if i == HEIGHT - 1 and cannon["pos_x"] == j:
                print("#", end="")

            # If there is nothing, we draw a blank cell
            else:
                print(" ", end="")
        
        # Next line
        print("|")

    # Display the bottom separation line
    print("+" + "-" * WIDTH + "+")



def start () -> None:

    """
        The main function of the game.
        In:
            * None.
        Out:
            * None.
    """

    # Initialize the game
    game = {"life": 1, "score": 0, "level": 1}
    cannon = {"pos_x": WIDTH // 2, "shoot": False}
    aliens = []

    # To get the keyboard hit
    kb = keyboardcapture.KeyboardCapture()

    # Main loop of the game
    current_action = ""
    while current_action != QUIT_KEY:

        # Catch the keyboard hit in the list of possible keys
        current_action = kb.get_char([LEFT_KEY, RIGHT_KEY, SHOOT_KEY, QUIT_KEY])

        # Display the current state of the game and sleep for a while
        display_game(game, cannon, aliens)
        sleep(SLEEP_TIME / game["level"])

    # Reset the terminal
    kb.set_normal_term()



# When this script is run directly, execute the following code
if __name__ == "__main__":

    # Start the game
    start()
// Not available yet
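In display_game, the board is written with many separate print calls. One possible alternative, which is easier to test because it returns the rows instead of printing them, is to build each row as a string and print the whole frame at once. The render_board function below is a hypothetical sketch that only draws the cannon, as in this first version of the game.

```python
from typing import Any, Dict, List

# Same board size as in spaceinvaders.py
WIDTH = 30
HEIGHT = 15


def render_board ( cannon: Dict[str, Any]
                 ) ->      List[str]:

    """
        Builds the rows of the game board as strings, without printing them.
        Only the cannon is drawn, as in the first version of display_game.
    """

    rows = []
    for i in range(HEIGHT):

        # Start from an empty row and place the cannon on the last one
        cells = [" "] * WIDTH
        if i == HEIGHT - 1:
            cells[cannon["pos_x"]] = "#"
        rows.append("|" + "".join(cells) + "|")
    return rows


# Print the whole frame with a single call
print("\n".join(render_board({"pos_x": WIDTH // 2, "shoot": False})))
```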

12.2 — Initializing and displaying the aliens

Define a function that takes as parameters:

  • An optional parameter for the number of aliens to create (default value 20).
  • An optional parameter indicating the fraction of each line to be filled with aliens (default value 0.4).

You have to create a dictionary for each alien containing the following keys:

  • pos_x – Its horizontal position.
  • pos_y – Its vertical position.
  • direction – A flag indicating whether the alien moves to the right or to the left. Note that all aliens move in the same direction at the beginning.

Your function should return a list of such dictionaries.

Complete the display_game and start functions to integrate the aliens into the game board. Here is a figure showing a possible initial placement of the aliens:

Correction
"""
    A homemade simplified version of the Space Invaders game.
"""

# Needed imports
import os
from typing import List, Dict, Tuple, Any, Optional
import keyboardcapture
from time import sleep



# Constants
WIDTH = 30
HEIGHT = 15
SLEEP_TIME = 0.08
LEFT_KEY = "k"
RIGHT_KEY = "m"
SHOOT_KEY = "o"
QUIT_KEY = "q"



def display_game ( game:   Dict[str, int],
                   cannon: Dict[str, Any],
                   aliens: List[Dict[str, Any]]
                 ) ->      None:


    """
        Function to display the game board.
        In:
            * game:   The game dictionary, containing the parameters of the game board (life, score, level).
            * cannon: The cannon dictionary, containing the parameters of the player's cannon (pos_x, shoot).
            * aliens: The aliens list containing the parameters of the aliens (pos_x, pos_y, direction).
        Out:
            * None.
    """

    # Clear the screen
    os.system("clear")
    
    # Display the top info
    print("+" + "-" * WIDTH + "+", "   Left:  <" + LEFT_KEY + ">")
    print("|" + "SCORE   LIFE   LEVEL".center(WIDTH, " ") + "|", "   Right: <" + RIGHT_KEY + ">")
    print("|" + (str(game["score"]).center(5, " ") + "   " + str(game["life"]).center(4, " ") + "   " + str(game["level"]).center(5, " ")).center(WIDTH, " ") + "|", "   Shoot: <" + SHOOT_KEY + ">")
    print("+" + "-" * WIDTH + "+", "   Quit:  <" + QUIT_KEY + ">")

    # Scan all the cells of the game board to check what has to be displayed
    for i in range(HEIGHT):
        print("|", end="")
        for j in range(WIDTH):

            ################################### NEW
            # Determine if an alien is present at this position
            alien_present = False
            for alien in aliens:
                if alien["pos_x"] == j and alien["pos_y"] == i:
                    alien_present = True
                    break
            ################################### /NEW

            # Display the cannon where needed
            if i == HEIGHT - 1 and cannon["pos_x"] == j:
                print("#", end="")
            
            ################################### NEW
            # Display the aliens where needed
            elif alien_present:
                print("@", end="")
            ################################### /NEW
            
            # If there is nothing, we draw a blank cell
            else:
                print(" ", end="")
        
        # Next line
        print("|")

    # Display the bottom separation line
    print("+" + "-" * WIDTH + "+")



################################### NEW
def init_aliens ( nb_aliens:         int = 20,
                  ratio_line_filled: float = 0.4
                ) ->                 List[Dict[str, Any]]:
    
    """
        Function to initialize the list of the aliens.
        In this version, they start on top at the center.
        In:
            * nb_aliens:         The number of aliens.
            * ratio_line_filled: The ratio of the line to be filled with aliens.
        Out:
            * aliens: The list of aliens, each alien being a dictionary with the parameters (pos_x, pos_y, direction).
    """
    
    # Initialize the list of aliens
    aliens = []
    
    # Determine the number of alien lines
    full_line_size = int(WIDTH * ratio_line_filled)
    nb_full_lines = nb_aliens // full_line_size
    final_line = nb_aliens % full_line_size
    full_line_margin = (WIDTH - full_line_size) // 2
    final_line_margin = (WIDTH - final_line) // 2

    # Create the aliens
    for l in range(nb_full_lines):
        for c in range(full_line_size):
            aliens.append({"pos_x": full_line_margin + c, "pos_y": l, "direction": "right"})
    for c in range(final_line):
        aliens.append({"pos_x": final_line_margin + c, "pos_y": nb_full_lines, "direction": "right"})

    # Return the list of aliens
    return aliens
################################### /NEW



def start () -> None:

    """
        The main function of the game.
        In:
            * None.
        Out:
            * None.
    """

    # Initialize the game
    game = {"life": 1, "score": 0, "level": 1}
    cannon = {"pos_x": WIDTH // 2, "shoot": False}
    ################################### NEW
    aliens = init_aliens()
    ################################### /NEW

    # To get the keyboard hit
    kb = keyboardcapture.KeyboardCapture()

    # Main loop of the game
    current_action = ""
    while current_action != QUIT_KEY:

        # Catch the keyboard hit in the list of possible keys
        current_action = kb.get_char([LEFT_KEY, RIGHT_KEY, SHOOT_KEY, QUIT_KEY])

        # Display the current state of the game and sleep for a while
        display_game(game, cannon, aliens)
        sleep(SLEEP_TIME / game["level"])

    # Reset the terminal
    kb.set_normal_term()



# When this script is run directly, execute the following code
if __name__ == "__main__":

    # Start the game
    start()
// Not available yet
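With the default values and WIDTH = 30, the arithmetic in init_aliens works out as follows: int(30 * 0.4) = 12 aliens per full line, 20 // 12 = 1 full line, 20 % 12 = 8 aliens on the final line, with margins (30 - 12) // 2 = 9 and (30 - 8) // 2 = 11. The sketch below isolates this computation in a hypothetical alien_layout function so that it can be checked on its own. Unlike init_aliens, it also guards against a ratio so small that full_line_size would be 0, which would make the integer division raise a ZeroDivisionError.

```python
from typing import Tuple


def alien_layout ( nb_aliens:         int,
                   ratio_line_filled: float,
                   width:             int = 30
                 ) ->                 Tuple[int, int, int, int, int]:

    """
        Computes the layout used to place the aliens.
        Out:
            * full_line_size, nb_full_lines, final_line, full_line_margin, final_line_margin.
    """

    # At least one alien per line, to avoid a division by zero
    full_line_size = max(1, int(width * ratio_line_filled))
    nb_full_lines = nb_aliens // full_line_size
    final_line = nb_aliens % full_line_size
    full_line_margin = (width - full_line_size) // 2
    final_line_margin = (width - final_line) // 2
    return full_line_size, nb_full_lines, final_line, full_line_margin, final_line_margin


# Default values of init_aliens
print(alien_layout(20, 0.4))  # (12, 1, 8, 9, 11)
```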

12.3 — Move the aliens and the cannon

At each iteration of the main loop, your cannon may move according to the key pressed. Identify useful functions to write, then modify the main loop to take cannon moves into account.

Then, as illustrated in the next figure, the aliens move together from left to right. When one of them reaches the right border, they all move down one line and then head to the left, and symmetrically when one of them reaches the left border. The direction key in each alien dictionary indicates its current direction.

Correction
"""
    A homemade simplified version of the Space Invaders game.
"""

# Needed imports
import os
from typing import List, Dict, Union, Any, Optional
import keyboardcapture
from time import sleep


# Constants
WIDTH = 30
HEIGHT = 15
SLEEP_TIME = 0.08
LEFT_KEY = "k"
RIGHT_KEY = "m"
SHOOT_KEY = "o"
QUIT_KEY = "q"



def display_game ( game:   Dict[str, int],
                   cannon: Dict[str, Any],
                   aliens: List[Dict[str, Any]]
                 ) ->      None:


    """
        Function to display the game board.
        In:
            * game:   The game dictionary, containing the parameters of the game board (life, score, level).
            * cannon: The cannon dictionary, containing the parameters of the player's cannon (pos_x, shoot).
            * aliens: The aliens list containing the parameters of the aliens (pos_x, pos_y, direction).
        Out:
            * None.
    """

    # Clear the screen
    os.system("clear")
    
    # Display the top info
    print("+" + "-" * WIDTH + "+", "   Left:  <" + LEFT_KEY + ">")
    print("|" + "SCORE   LIFE   LEVEL".center(WIDTH, " ") + "|", "   Right: <" + RIGHT_KEY + ">")
    print("|" + (str(game["score"]).center(5, " ") + "   " + str(game["life"]).center(4, " ") + "   " + str(game["level"]).center(5, " ")).center(WIDTH, " ") + "|", "   Shoot: <" + SHOOT_KEY + ">")
    print("+" + "-" * WIDTH + "+", "   Quit:  <" + QUIT_KEY + ">")

    # Scan all the cells of the game board to check what has to be displayed
    for i in range(HEIGHT):
        print("|", end="")
        for j in range(WIDTH):

            # Determine if an alien is present at this position
            alien_present = False
            for alien in aliens:
                if alien["pos_x"] == j and alien["pos_y"] == i:
                    alien_present = True
                    break

            # Display the cannon where needed
            if i == HEIGHT - 1 and cannon["pos_x"] == j:
                print("#", end="")
            
            # Display the aliens where needed
            elif alien_present:
                print("@", end="")
            
            # If there is nothing, we draw a blank cell
            else:
                print(" ", end="")
        
        # Next line
        print("|")

    # Display the bottom separation line
    print("+" + "-" * WIDTH + "+")



def init_aliens ( nb_aliens:         int = 20,
                  ratio_line_filled: float = 0.4
                ) ->                 List[Dict[str, Any]]:
    
    """
        Function to initialize the list of the aliens.
        In this version, they start on top at the center.
        In:
            * nb_aliens:         The number of aliens.
            * ratio_line_filled: The ratio of the line to be filled with aliens.
        Out:
            * aliens: The list of aliens, each alien being a dictionary with the parameters (pos_x, pos_y, direction).
    """
    
    # Initialize the list of aliens
    aliens = []
    
    # Determine the number of alien lines
    full_line_size = int(WIDTH * ratio_line_filled)
    nb_full_lines = nb_aliens // full_line_size
    final_line = nb_aliens % full_line_size
    full_line_margin = (WIDTH - full_line_size) // 2
    final_line_margin = (WIDTH - final_line) // 2

    # Create the aliens
    for l in range(nb_full_lines):
        for c in range(full_line_size):
            aliens.append({"pos_x": full_line_margin + c, "pos_y": l, "direction": "right"})
    for c in range(final_line):
        aliens.append({"pos_x": final_line_margin + c, "pos_y": nb_full_lines, "direction": "right"})

    # Return the list of aliens
    return aliens



################################### NEW
def move_cannon ( cannon: Dict[str, Any],
                  action: str
                ) ->      None:
    
    """
        Function to move the cannon.
        This function will update the cannon dictionary directly.
        In:
            * cannon: The cannon dictionary, containing the parameters of the player's cannon (pos_x, shoot).
            * action: The action asked by the player.
        Out:
            * None.
    """
    
    # Move the cannon left
    if action == LEFT_KEY and cannon["pos_x"] > 0:
        cannon["pos_x"] -= 1

    # Move the cannon right
    elif action == RIGHT_KEY and cannon["pos_x"] < WIDTH - 1:
        cannon["pos_x"] += 1
################################### /NEW



################################### NEW
def move_aliens ( aliens: List[Dict[str, Any]]
                ) ->      None:
    
    """
        Function to move the aliens.
        This function will update the aliens list directly.
        In:
            * aliens: The aliens list containing the parameters of the aliens (pos_x, pos_y, direction).
        Out:
            * None.
    """

    # If an alien reaches the border, all aliens go down and change direction
    if any([(alien["pos_x"] == 0 and alien["direction"] == "left") or (alien["pos_x"] == WIDTH - 1 and alien["direction"] == "right") for alien in aliens]):
        for alien in aliens:
            alien["pos_y"] += 1
            alien["direction"] = "left" if alien["direction"] == "right" else "right"

    # Otherwise, all aliens move in their direction
    else:
        for alien in aliens:
            alien["pos_x"] += 1 if alien["direction"] == "right" else -1
################################### /NEW



def start () -> None:

    """
        The main function of the game.
        In:
            * None.
        Out:
            * None.
    """

    # Initialize the game
    game = {"life": 1, "score": 0, "level": 1}
    cannon = {"pos_x": WIDTH // 2, "shoot": False}
    aliens = init_aliens()

    # To get the keyboard hit
    kb = keyboardcapture.KeyboardCapture()

    # Main loop of the game
    current_action = ""
    while current_action != QUIT_KEY:

        # Catch the keyboard hit in the list of possible keys
        current_action = kb.get_char([LEFT_KEY, RIGHT_KEY, SHOOT_KEY, QUIT_KEY])

        ################################### NEW
        # Move the cannon and the aliens
        if current_action in [LEFT_KEY, RIGHT_KEY]:
            move_cannon(cannon, current_action)
        move_aliens(aliens)
        ################################### /NEW

        # Display the current state of the game and sleep for a while
        display_game(game, cannon, aliens)
        sleep(SLEEP_TIME / game["level"])

    # Reset the terminal
    kb.set_normal_term()



# When this script is run directly, execute the following code
if __name__ == "__main__":

    # Start the game
    start()
// Not available yet
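Note that move_cannon and move_aliens return nothing: they modify the dictionaries they receive. This works because a parameter refers to the same object as the argument passed by the caller, so changing a value inside the dictionary is visible after the call, whereas assigning a new dictionary to the parameter only rebinds the local name. The two hypothetical functions below illustrate the difference, in line with the notions of variable visibility from part 1.

```python
from typing import Any, Dict


def shift_in_place ( cannon: Dict[str, Any]
                   ) ->      None:

    # Modifies the dictionary object shared with the caller
    cannon["pos_x"] += 1


def shift_by_rebinding ( cannon: Dict[str, Any]
                       ) ->      None:

    # Makes the local name point to a new dictionary: the caller's dictionary is untouched
    cannon = {"pos_x": cannon["pos_x"] + 1}


my_cannon = {"pos_x": 15}
shift_in_place(my_cannon)
print(my_cannon["pos_x"])  # 16
shift_by_rebinding(my_cannon)
print(my_cannon["pos_x"])  # 16
```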

12.4 — Shoot aliens

When the o key is pressed, the cannon fires a vertical shot. To take this action into account, add a function that checks whether an alien is hit by this shot. If so, remove the alien from the list.

The function displaying the game board has to be updated to show the shot.

Modify the continuation condition of your main game loop to check that the game is not over (the game ends when there are no more aliens or when your cannon is destroyed by an alien).
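A common pitfall when removing shot aliens is to call remove on a list while a for loop is iterating over that same list: the loop then skips the element that follows each removal. One safe pattern, sketched below with a hypothetical function name, is to first find the alien to remove (the lowest one in the cannon's column, since it is the first one the shot meets), and only then remove it.

```python
from typing import Any, Dict, List, Optional


def lowest_alien_in_column ( aliens: List[Dict[str, Any]],
                             column: int
                           ) ->      Optional[Dict[str, Any]]:

    """
        Returns the alien with the largest pos_y in the given column, or None if the column is empty.
    """

    aliens_in_column = [alien for alien in aliens if alien["pos_x"] == column]
    if len(aliens_in_column) == 0:
        return None
    return max(aliens_in_column, key=lambda alien: alien["pos_y"])


# Pitfall: removing while iterating skips the element right after each removal
numbers = [1, 2, 2, 3]
for n in numbers:
    if n == 2:
        numbers.remove(n)
print(numbers)  # [1, 2, 3]

# Safe pattern: find the target first, then remove it outside of any loop over the list
aliens = [{"pos_x": 5, "pos_y": 0, "direction": "right"},
          {"pos_x": 5, "pos_y": 1, "direction": "right"},
          {"pos_x": 6, "pos_y": 1, "direction": "right"}]
target = lowest_alien_in_column(aliens, 5)
if target is not None:
    aliens.remove(target)
print(len(aliens))  # 2
```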

Correction
"""
    A homemade simplified version of the Space Invaders game.
"""

# Needed imports
import os
from typing import List, Dict, Union, Any, Optional
import keyboardcapture
from time import sleep



# Constants
WIDTH = 30
HEIGHT = 15
SLEEP_TIME = 0.08
LEFT_KEY = "k"
RIGHT_KEY = "m"
SHOOT_KEY = "o"
QUIT_KEY = "q"



def display_game ( game:   Dict[str, int],
                   cannon: Dict[str, Any],
                   aliens: List[Dict[str, Any]]
                 ) ->      None:


    """
        Function to display the game board.
        In:
            * game:   The game dictionary, containing the parameters of the game board (life, score, level).
            * cannon: The cannon dictionary, containing the parameters of the player's cannon (pos_x, fire).
            * aliens: The aliens list containing the parameters of the aliens (pos_x, pos_y, direction).
        Out:
            * None.
    """

    # Clear the screen
    os.system("clear")
    
    # Display the top info
    print("+" + "-" * WIDTH + "+", "   Left:  <" + LEFT_KEY + ">")
    print("|" + "SCORE   LIFE   LEVEL".center(WIDTH, " ") + "|", "   Right: <" + RIGHT_KEY + ">")
    print("|" + (str(game["score"]).center(5, " ") + "   " + str(game["life"]).center(4, " ") + "   " + str(game["level"]).center(5, " ")).center(WIDTH, " ") + "|", "   Shoot: <" + SHOOT_KEY + ">")
    print("+" + "-" * WIDTH + "+", "   Quit:  <" + QUIT_KEY + ">")

    ################################### NEW
    # Determine if an alien is being shot
    alien_shot_y = None
    if cannon["fire"]:
        aliens_in_column = [al["pos_y"] for al in aliens if al["pos_x"] == cannon["pos_x"]]
        if len(aliens_in_column) > 0:
            alien_shot_y = max(aliens_in_column)
    ################################### /NEW

    # Scan all the cells of the game board to check what has to be displayed
    for i in range(HEIGHT):
        print("|", end="")
        for j in range(WIDTH):

            # Determine if an alien is present at this position
            alien_present = False
            for alien in aliens:
                if alien["pos_x"] == j and alien["pos_y"] == i:
                    alien_present = True
                    break

            # Display the cannon where needed
            if i == HEIGHT - 1 and cannon["pos_x"] == j:
                print("#", end="")
            
            ################################### NEW
            # Display the aliens where needed
            elif alien_present:
                if i == alien_shot_y and j == cannon["pos_x"]:
                    print("*", end="")
                else:
                    print("@", end="")
            
            # Display the cannon fire where needed
            elif cannon["fire"] and cannon["pos_x"] == j and (alien_shot_y if alien_shot_y is not None else -1) < i < HEIGHT - 1:
                print(":", end="")
            ################################### /NEW

            # If there is nothing, we draw a blank cell
            else:
                print(" ", end="")
        
        # Next line
        print("|")

    # Display the bottom separation line
    print("+" + "-" * WIDTH + "+")



def init_aliens ( nb_aliens:         int = 20,
                  ratio_line_filled: float = 0.4
                ) ->                 List[Dict[str, Any]]:
    
    """
        Function to initialize the list of the aliens.
        In this version, they start on top at the center.
        In:
            * nb_aliens:         The number of aliens.
            * ratio_line_filled: The ratio of the line to be filled with aliens.
        Out:
            * aliens: The list of aliens, each alien being a dictionary with the parameters (pos_x, pos_y, direction).
    """
    
    # Initialize the list of aliens
    aliens = []
    
    # Determine the number of alien lines
    full_line_size = int(WIDTH * ratio_line_filled)
    nb_full_lines = nb_aliens // full_line_size
    final_line = nb_aliens % full_line_size
    full_line_margin = (WIDTH - full_line_size) // 2
    final_line_margin = (WIDTH - final_line) // 2

    # Create the aliens
    for l in range(nb_full_lines):
        for c in range(full_line_size):
            aliens.append({"pos_x": full_line_margin + c, "pos_y": l, "direction": "right"})
    for c in range(final_line):
        aliens.append({"pos_x": final_line_margin + c, "pos_y": nb_full_lines, "direction": "right"})

    # Return the list of aliens
    return aliens



def move_cannon ( cannon: Dict[str, Any],
                  action: str
                ) ->      None:
    
    """
        Function to move the cannon.
        This function will update the cannon dictionary directly.
        In:
            * cannon: The cannon dictionary, containing the parameters of the player's cannon (pos_x, fire).
            * action: The action asked by the player.
        Out:
            * None.
    """
    
    # Move the cannon left
    if action == LEFT_KEY and cannon["pos_x"] > 0:
        cannon["pos_x"] -= 1

    # Move the cannon right
    elif action == RIGHT_KEY and cannon["pos_x"] < WIDTH - 1:
        cannon["pos_x"] += 1



def move_aliens ( aliens: List[Dict[str, Any]]
                ) ->      None:
    
    """
        Function to move the aliens.
        This function will update the aliens list directly.
        In:
            * aliens: The aliens list containing the parameters of the aliens (pos_x, pos_y, direction).
        Out:
            * None.
    """

    # If an alien reaches the border, all aliens go down and change direction
    if any([(alien["pos_x"] == 0 and alien["direction"] == "left") or (alien["pos_x"] == WIDTH - 1 and alien["direction"] == "right") for alien in aliens]):
        for alien in aliens:
            alien["pos_y"] += 1
            alien["direction"] = "left" if alien["direction"] == "right" else "right"

    # Otherwise, all aliens move in their direction
    else:
        for alien in aliens:
            alien["pos_x"] += 1 if alien["direction"] == "right" else -1



################################### NEW
def check_alien_shot ( cannon: Dict[str, Any],
                       aliens: List[Dict[str, Any]],
                       game:   Dict[str, int]
                     ) ->      None:
    
    """
        Function to check if an alien was shot by the cannon.
        This function will update the game dictionary directly.
        Any alien shot will be removed from the aliens list.
        The aliens list will be updated directly.
        In:
            * cannon: The cannon dictionary, containing the parameters of the player's cannon (pos_x, fire).
            * aliens: The aliens list containing the parameters of the aliens (pos_x, pos_y, direction).
            * game:   The game dictionary, containing the parameters of the game board (life, score, level).
        Out:
            * None.
    """

    # If the cannon fired, the lowest alien in its column is hit
    if cannon["fire"]:
        aliens_in_column = [al for al in aliens if al["pos_x"] == cannon["pos_x"]]
        if len(aliens_in_column) > 0:

            # Remove the hit alien (outside of any loop over the list) and increase the score
            alien_shot = max(aliens_in_column, key=lambda al: al["pos_y"])
            aliens.remove(alien_shot)
            game["score"] += 1
################################### /NEW



################################### NEW
def game_is_over ( cannon: Dict[str, Any],
                   aliens: List[Dict[str, Any]]
                 ) ->      tuple[bool, Optional[str]]:
    
    """
        Function to check if the game is finished.
        The game is finished if all the aliens are dead or if they reach the cannon.
        In:
            * cannon: The cannon dictionary, containing the parameters of the player's cannon (pos_x, fire).
            * aliens: The aliens list containing the parameters of the aliens (pos_x, pos_y, direction).
        Out:
            * over:   A boolean indicating if the game is finished.
            * winner: A string indicating if the player or the aliens won.
    """

    # The game is finished if all the aliens are dead or if they reach the cannon
    all_aliens_dead = len(aliens) == 0
    player_dead = any(alien["pos_x"] == cannon["pos_x"] and alien["pos_y"] == HEIGHT - 1 for alien in aliens)

    # Return the result
    if all_aliens_dead:
        return True, "player"
    elif player_dead:
        return True, "aliens"
    else:
        return False, None
################################### /NEW



def start () -> None:

    """
        The main function of the game.
        In:
            * None.
        Out:
            * None.
    """

    # Initialize the game
    game = {"life": 1, "score": 0, "level": 1}
    cannon = {"pos_x": WIDTH // 2, "fire": False}
    aliens = init_aliens()

    # To get the keyboard hit
    kb = keyboardcapture.KeyboardCapture()

    # Main loop of the game
    current_action = ""
    ################################### NEW
    over, winner = False, None
    while current_action != QUIT_KEY and not over:
    ################################### /NEW

        # Catch the keyboard hit in the list of possible keys
        current_action = kb.get_char([LEFT_KEY, RIGHT_KEY, SHOOT_KEY, QUIT_KEY])

        ################################### NEW
        # Fire the cannon if the player pressed the shoot key
        if current_action == SHOOT_KEY:
            cannon["fire"] = True
        else:
            cannon["fire"] = False
        ################################### /NEW

        # Move the cannon and the aliens
        if current_action in [LEFT_KEY, RIGHT_KEY]:
            move_cannon(cannon, current_action)
        move_aliens(aliens)

        # Display the current state of the game and sleep for a while
        display_game(game, cannon, aliens)
        sleep(SLEEP_TIME / game["level"])

        ################################### NEW
        # Check if an alien was shot and if the game is over
        check_alien_shot(cannon, aliens, game)
        over, winner = game_is_over(cannon, aliens)
        ################################### /NEW

    # Reset the terminal
    kb.set_normal_term()
   

    ################################### NEW
    # Display the final state and the winner
    display_game(game, cannon, aliens)
    if winner == "player":
        print("You won!")
    elif winner == "aliens":
        print("You lost!")
    ################################### /NEW



# When this script is run directly, execute the following code
if __name__ == "__main__":

    # Start the game
    start()

12.5 — Pimp your game

For example, you can extend your game with the following features:

  • Play until your life counter is 0.
  • Add special shots that fall from the sky and allow your ship to destroy a whole column of aliens.
  • Add special aliens that only die after receiving multiple shots.
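The last two suggestions can be sketched as variants of check_alien_shot. This is only one possible design, relying on assumptions that are not part of the game above: an extra "special" key in the cannon dictionary signals a special shot, and each alien carries a "hits" counter with the number of shots it can still take (so init_aliens would have to set it). The function names check_special_shot and check_armored_shot are hypothetical.

```python
from typing import Any, Dict, List


def check_special_shot(cannon: Dict[str, Any],
                       aliens: List[Dict[str, Any]],
                       game: Dict[str, int]) -> None:
    """Hypothetical special shot: destroys every alien in the cannon's column."""
    # Assumption: cannon["special"] is set to True when a special shot is fired
    if cannon.get("special", False):
        in_column = [alien for alien in aliens if alien["pos_x"] == cannon["pos_x"]]
        game["score"] += len(in_column)
        # Slice assignment updates the caller's list in place
        aliens[:] = [alien for alien in aliens if alien["pos_x"] != cannon["pos_x"]]


def check_armored_shot(cannon: Dict[str, Any],
                       aliens: List[Dict[str, Any]],
                       game: Dict[str, int]) -> None:
    """Hypothetical variant of check_alien_shot for aliens that need several shots."""
    if cannon["fire"]:
        in_column = [alien for alien in aliens if alien["pos_x"] == cannon["pos_x"]]
        if in_column:
            # Assumption: each alien has a "hits" counter (shots it can still take)
            target = max(in_column, key=lambda alien: alien["pos_y"])
            target["hits"] -= 1
            if target["hits"] <= 0:
                game["score"] += 1
                aliens.remove(target)


# Small usage example with three aliens, one of which needs two shots
game = {"life": 1, "score": 0, "level": 1}
aliens = [{"pos_x": 3, "pos_y": 0, "direction": 1, "hits": 1},
          {"pos_x": 3, "pos_y": 1, "direction": 1, "hits": 2},
          {"pos_x": 5, "pos_y": 1, "direction": 1, "hits": 1}]
cannon = {"pos_x": 3, "fire": True}
check_armored_shot(cannon, aliens, game)  # lowest alien in column 3 survives with 1 hit left
check_armored_shot(cannon, aliens, game)  # the second shot destroys it
cannon["special"] = True
check_special_shot(cannon, aliens, game)  # destroys the remaining alien in column 3
print(game["score"], len(aliens))  # 2 1
```

Keeping these as separate functions lets the main loop decide which check to call, so the original check_alien_shot stays unchanged.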

To go beyond

Let’s come back to your linear regressor for these two additional exercises.

13 — Change the data generation for the test set

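It would be interesting to evaluate the quality of your linear regressor when the test set follows a different statistical distribution from the training set, for instance when it is generated with a different slope. How large a difference can the regressor tolerate before its predictions degrade significantly?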

This is a common problem in machine learning: the available training data may differ from the data the model will actually be used on.
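A minimal sketch of this experiment, using only the standard random module. Since your own regressor module is not reproduced here, it swaps in a closed-form ordinary least squares fit; the names generate_data, fit_least_squares and mse, as well as all numeric settings (slopes, noise level, x drawn in [0, 1]), are illustrative assumptions to replace with your own functions and parameters.

```python
import random
from typing import Dict, List, Tuple


def generate_data(n: int, slope: float, intercept: float,
                  noise: float, rng: random.Random) -> Tuple[List[float], List[float]]:
    """Hypothetical generator: y = slope * x + intercept + Gaussian noise, x uniform in [0, 1]."""
    xs = [rng.uniform(0.0, 1.0) for _ in range(n)]
    ys = [slope * x + intercept + rng.gauss(0.0, noise) for x in xs]
    return xs, ys


def fit_least_squares(xs: List[float], ys: List[float]) -> Tuple[float, float]:
    """Closed-form ordinary least squares for y = a * x + b (stand-in for your own regressor)."""
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    a = cov_xy / var_x
    return a, mean_y - a * mean_x


def mse(a: float, b: float, xs: List[float], ys: List[float]) -> float:
    """Mean squared error of the model y = a * x + b on a dataset."""
    return sum((a * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


rng = random.Random(0)
train_x, train_y = generate_data(800, slope=2.0, intercept=1.0, noise=0.1, rng=rng)
a, b = fit_least_squares(train_x, train_y)

# Evaluate on test sets whose slope drifts further and further from the training slope
results: Dict[float, float] = {}
for test_slope in [2.0, 2.1, 2.3, 2.5, 3.0]:
    test_x, test_y = generate_data(200, slope=test_slope, intercept=1.0, noise=0.1, rng=rng)
    results[test_slope] = mse(a, b, test_x, test_y)
    print(f"test slope {test_slope:.1f}: MSE = {results[test_slope]:.4f}")
```

With the noise level fixed, the MSE at the training slope should stay close to the noise variance (0.01 with these settings), so any increase beyond that comes from the difference between the two distributions.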

14 — Adapt the model

To mitigate this problem, we can use a small dataset that follows the same distribution as the test set (the test set itself is generally unknown when the model is trained).

Assume you have:

  • A training set of 800 points with a given slope.
  • A test set of 200 points with a (slightly) different slope than the training set.
  • An adaptation set of 50 points with the same slope as the test set.

Train a linear regressor on the training set, then run a few additional training epochs on the adaptation set to adapt it. What results do you get on the test set? Are they better than those of the non-adapted model?
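A sketch of the adaptation experiment, self-contained and resting on assumptions: the regressor is trained with full-batch gradient descent on the mean squared error (your own implementation may use a different update rule or learning rate), x values are drawn in [0, 1], the training slope is 2.0 and the test and adaptation slope is 2.3. The names generate_data, gradient_descent and mse are hypothetical.

```python
import random
from typing import List, Tuple


def generate_data(n: int, slope: float, intercept: float,
                  noise: float, rng: random.Random) -> Tuple[List[float], List[float]]:
    """Hypothetical generator: y = slope * x + intercept + Gaussian noise, x uniform in [0, 1]."""
    xs = [rng.uniform(0.0, 1.0) for _ in range(n)]
    ys = [slope * x + intercept + rng.gauss(0.0, noise) for x in xs]
    return xs, ys


def gradient_descent(xs: List[float], ys: List[float], a: float, b: float,
                     learning_rate: float, epochs: int) -> Tuple[float, float]:
    """Full-batch gradient descent on the MSE of y = a * x + b, starting from (a, b)."""
    n = len(xs)
    for _ in range(epochs):
        errors = [a * x + b - y for x, y in zip(xs, ys)]
        grad_a = 2.0 / n * sum(e * x for e, x in zip(errors, xs))
        grad_b = 2.0 / n * sum(errors)
        a -= learning_rate * grad_a
        b -= learning_rate * grad_b
    return a, b


def mse(a: float, b: float, xs: List[float], ys: List[float]) -> float:
    """Mean squared error of the model y = a * x + b on a dataset."""
    return sum((a * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


rng = random.Random(42)
train_x, train_y = generate_data(800, 2.0, 1.0, 0.1, rng)
test_x, test_y = generate_data(200, 2.3, 1.0, 0.1, rng)
adapt_x, adapt_y = generate_data(50, 2.3, 1.0, 0.1, rng)

# Train from scratch on the training set (learning rate chosen for x in [0, 1])
a_base, b_base = gradient_descent(train_x, train_y, 0.0, 0.0, learning_rate=0.5, epochs=500)

# Adapt: a few more epochs on the small adaptation set, starting from the trained parameters
a_adapted, b_adapted = gradient_descent(adapt_x, adapt_y, a_base, b_base, learning_rate=0.5, epochs=10)

base_mse = mse(a_base, b_base, test_x, test_y)
adapted_mse = mse(a_adapted, b_adapted, test_x, test_y)
print(f"Test MSE without adaptation: {base_mse:.4f}")
print(f"Test MSE with adaptation:    {adapted_mse:.4f}")
```

Starting the adaptation from the trained parameters, rather than from scratch, is what lets 50 points and a handful of epochs be enough. Varying the number of adaptation epochs or the gap between the two slopes is a simple way to explore how far this helps.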