2D Scatter Plot
A 2D scatter plot is one of the simplest and most useful plotting tools in Exploratory Data Analysis (EDA). It takes the data points in our dataset and plots them on a chart. The position of each point is determined by its two values: one gives its position along the horizontal axis and the other its position along the vertical axis. It is called a scatter plot because all the points from our dataset are scattered across the chart.
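As a minimal sketch of the idea (the five (x, y) pairs below are made up), each pair becomes one point whose horizontal position comes from x and whose vertical position comes from y:

```python
import matplotlib.pyplot as plt

# Hypothetical data: each (x, y) pair becomes one point on the chart
x = [1, 2, 3, 4, 5]
y = [2, 4, 1, 5, 3]

fig, ax = plt.subplots()
points = ax.scatter(x, y)  # horizontal position from x, vertical position from y
ax.set_xlabel("x")
ax.set_ylabel("y")
plt.show()
```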
A 2D scatter plot primarily consists of three components:
- The X-axis (horizontal dimension)
- The Y-axis (vertical dimension)
- The scale of the plot
Beyond these, we can add different colors to classify the points, or vary the marker sizes to reveal more of the structure in our data.
All three basic components are needed for a proper understanding of the plot. After reading the X and Y axes, take care to understand the scale too, because the origin of the plot might not be (0,0). If we use different colors for classification, we should add a legend that shows which color represents which label.
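A small sketch of reading the axes carefully, using two made-up groups of heights and weights: Matplotlib autoscales the axis limits to the data, so neither axis starts at 0 here, and the legend is what tells the reader which color is which group.

```python
import matplotlib.pyplot as plt

# Hypothetical heights (cm) and weights (kg) for two made-up groups
heights_a, weights_a = [150, 155, 160], [50, 54, 58]
heights_b, weights_b = [175, 180, 185], [70, 75, 80]

fig, ax = plt.subplots()
ax.scatter(heights_a, weights_a, label="group A")
ax.scatter(heights_b, weights_b, label="group B")
ax.set_xlabel("Height (cm)")  # always name what each axis measures
ax.set_ylabel("Weight (kg)")
ax.legend()  # maps each color to its group label

# Limits are fitted to the data, so the origin (0, 0) is not on this chart
print(ax.get_xlim(), ax.get_ylim())
plt.show()
```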
Scatter plots are used to understand the relationships between variables, for example whether the value of one variable tends to increase or decrease as the other changes.
Take the example of the Maths results for a particular class. If we put the number of hours each student has studied (as a daily average or a cumulative total) on the x-axis and the marks they received on the y-axis, the plot will help us see the trend: how the scores shift from the minimum to the maximum as study time increases. Of course, correlation does not imply causation, but combined with good domain knowledge it can point us in the right direction.
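import matplotlib.pyplot as plt
# `results` is assumed to be a pandas DataFrame loaded earlier with these two columns
results.plot(kind='scatter', x='No. of hours studied', y='Marks received')
plt.xlabel('No. of hours studied daily')
plt.title('Maths Score and hours spent in studying daily')
plt.show()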
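Because the paragraph above mentions correlation, here is a hedged sketch of putting a number on the trend with pandas `Series.corr` (Pearson by default). The `results` values below are made up for illustration; only the column names mirror the example.

```python
import pandas as pd

# Hypothetical class results: column names mirror the example, numbers are made up
results = pd.DataFrame({
    "No. of hours studied": [1, 2, 3, 4, 5],
    "Marks received": [52, 58, 65, 70, 78],
})

# Pearson correlation: +1 is a perfect increasing linear relationship,
# -1 a perfect decreasing one, and 0 means no linear relationship
r = results["No. of hours studied"].corr(results["Marks received"])
print(round(r, 3))  # → 0.998
```

A value close to +1 only says that the relationship is strongly linear and increasing; on its own it says nothing about whether studying causes the higher marks.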
A simple 2D scatter plot helps us see the range (the minimum and maximum values) and where the points are concentrated along the two axes under consideration. If we also have a class label (or a categorical dependent variable) and want to see how it relates to these two variables, we can color-code the points to check whether the classes follow a particular pattern. In that case, we should add a legend that shows which color represents which label.
A widely used example of this is the iris dataset: when we color-code all three species of flower, we can see that the ‘Setosa’ species forms a cluster clearly separated from the other two.
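import matplotlib.pyplot as plt
import seaborn as sns
# Assumption: use seaborn's example copy of the iris data (downloaded on first use);
# a local CSV with the same column names works just as well
iris = sns.load_dataset("iris")
# Plain scatter plot: every flower drawn in the same color
iris.plot(kind='scatter', x='sepal_length', y='sepal_width')
plt.show()
# Color-coded by species; `height` replaces FacetGrid's old `size` argument
sns.set_style("whitegrid")
sns.FacetGrid(iris, hue="species", height=4) \
    .map(plt.scatter, "sepal_length", "sepal_width") \
    .add_legend()
plt.show()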
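The FacetGrid call hides the mechanics of color coding. As a sketch of what it amounts to, assuming a tiny six-row sample in the iris column layout (`df` is a hypothetical stand-in for the full dataset), plain Matplotlib can draw one scatter call per species, each with its own label:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Six sample rows (two per species) in the iris column layout
df = pd.DataFrame({
    "sepal_length": [5.1, 4.9, 7.0, 6.4, 6.3, 5.8],
    "sepal_width": [3.5, 3.0, 3.2, 3.2, 3.3, 2.7],
    "species": ["setosa", "setosa", "versicolor", "versicolor",
                "virginica", "virginica"],
})

fig, ax = plt.subplots()
# One scatter call per class: each call takes the next color from the default
# cycle, and its label ties that color to the class name in the legend
for species, group in df.groupby("species"):
    ax.scatter(group["sepal_length"], group["sepal_width"], label=species)
ax.set_xlabel("sepal_length")
ax.set_ylabel("sepal_width")
ax.legend(title="species")
plt.show()
```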
We can go one step further and add another dimension to the plot: size. Apart from plotting our data points and/or color-coding them, we can vary the size of the points. For example, let’s assume we are studying the Nobel Prizes won by developed countries. We can put a country index on the x-axis and the number of Nobel Prizes won on the y-axis, and encode each country's population as the size of its point, so the chart shows both the population of each country and how many Nobel Prizes it has won. If we take out the outlier, the United States, we get a clearer picture of the remaining countries: the United States not only has a huge population but has also received almost three times as many prizes as the United Kingdom, which is in second place.
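import matplotlib.pyplot as plt
# `new` is assumed to be a pandas DataFrame with one row per country. The column
# names used below ('country', 'country_index', 'prizes', 'population') are
# hypothetical placeholders: the original column selections were lost.
size = 0.00001 * new['population'].astype(int)  # scale population down to marker areas
plt.figure(figsize=(15, 8.5))
plt.scatter(new['country_index'], new['prizes'], alpha=0.2, s=size,
            c=new['population'], cmap='viridis')  # color column is also a placeholder
plt.xlabel('COUNTRY INDEX')
plt.ylabel('NO. OF NOBEL PRIZES WON')
# Label every bubble with its country name
for name, x, y in zip(new['country'], new['country_index'], new['prizes']):
    plt.annotate(name, (x, y))
plt.colorbar()
plt.show()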
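A sketch of the size dimension and of dropping an outlier, using made-up countries "A" to "E" rather than real Nobel Prize data (the column names and the helper `bubble_plot` are hypothetical). Matplotlib's `s` argument is the marker area in points squared, so scaling it linearly with population keeps bubble area proportional to population:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical countries and numbers, purely for illustration
df = pd.DataFrame({
    "country": ["A", "B", "C", "D", "E"],
    "prizes": [12, 30, 8, 95, 20],
    "population_m": [10, 60, 5, 320, 80],  # population in millions
})

def bubble_plot(data, ax):
    # `s` is marker area (points^2), so area grows linearly with population
    return ax.scatter(range(len(data)), data["prizes"],
                      s=data["population_m"] * 5, alpha=0.4)

fig, (ax_all, ax_trim) = plt.subplots(1, 2, figsize=(10, 4))
bubble_plot(df, ax_all)
ax_all.set_title("All countries")

# Drop the outlier (the country with the most prizes) to see the rest more clearly
trimmed = df[df["prizes"] < df["prizes"].max()]
bubble_plot(trimmed, ax_trim)
ax_trim.set_title("Outlier removed")

for ax in (ax_all, ax_trim):
    ax.set_xlabel("Country index")
    ax.set_ylabel("Prizes won")
plt.show()
```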
Just for fun, I have also added one more dimension to the Iris dataset, using ‘petal width’ as the marker size, to see how it helps us differentiate the species.
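import numpy as np
import matplotlib.pyplot as plt
# Assumes `iris` is a DataFrame with columns in the order
# sepal_length, sepal_width, petal_length, petal_width, species
iris_arr = np.array(iris)
new = np.copy(iris_arr)
# Encode species names as numbers so they can drive the color map
new[iris_arr == 'setosa'] = 1
new[iris_arr == 'versicolor'] = 2
new[iris_arr == 'virginica'] = 3
new = new.T.astype(float)  # one row per feature; cast from object to float
size = (100 * new[3]).astype(int)  # marker size from petal_width
plt.figure(figsize=(10, 6))
plt.scatter(new[0], new[1], alpha=0.8, s=size, c=new[4], cmap='viridis')
plt.xlabel('sepal_length')
plt.ylabel('sepal_width')
plt.colorbar()
plt.show()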
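As an alternative sketch of the same encoding, pandas `Series.map` can assign the species codes by name rather than relying on column positions in a NumPy array. The three sample rows in the hypothetical `iris_df` stand in for the full dataset:

```python
import matplotlib.pyplot as plt
import pandas as pd

# Three sample rows in the iris layout; use the full dataset in practice
iris_df = pd.DataFrame({
    "sepal_length": [5.1, 7.0, 6.3],
    "sepal_width": [3.5, 3.2, 3.3],
    "petal_width": [0.2, 1.4, 2.5],
    "species": ["setosa", "versicolor", "virginica"],
})

# Map species names to numeric codes by name, not by array position
codes = iris_df["species"].map({"setosa": 1, "versicolor": 2, "virginica": 3})
sizes = 100 * iris_df["petal_width"]  # petal width drives marker area

plt.figure(figsize=(10, 6))
sc = plt.scatter(iris_df["sepal_length"], iris_df["sepal_width"],
                 s=sizes, c=codes, cmap="viridis", alpha=0.8)
plt.xlabel("sepal_length")
plt.ylabel("sepal_width")
plt.colorbar(sc)
plt.show()
```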
Scatter plots are a great visualization tool and an integral part of EDA. If you have any queries or feedback, you can get in touch with me at email@example.com. The IPython notebook for the examples above can be found on my GitHub profile.